# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from mmcv.utils import print_log
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class GolfDataset(CustomDataset):
    """Semantic segmentation dataset with two classes: road and grass.

    Label index ``i`` in a prediction is rendered with ``PALETTE[i]``, so
    ``CLASSES`` and ``PALETTE`` must stay in the same order.

    Args:
        img_suffix (str): Suffix of the image files.
            Default: '_leftImg8bit.png'.
        seg_map_suffix (str): Suffix of the segmentation map files.
            Default: '_gtFine_labelIds.png'.
        **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
    """

    CLASSES = ('road', 'grass')

    PALETTE = [
        [128, 64, 128],  # road
        [0, 255, 0],  # grass
    ]

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 seg_map_suffix='_gtFine_labelIds.png',
                 **kwargs):
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)

    def results2img(self, results, imgfile_prefix, indices=None):
        """Write the segmentation results to palettised PNG images.

        Each result is cast to ``uint8``, converted to a 'P' mode image,
        given the class palette and saved as ``<basename>.png`` inside
        ``imgfile_prefix``, where ``<basename>`` is the extension-less base
        name of ``self.img_infos[idx]['filename']``.

        Args:
            results (Iterable): Per-image predictions; each item must
                support ``.astype``. Presumably 2-D label maps - confirm
                against the caller.
            imgfile_prefix (str): Output directory. It is created if it
                does not exist.
            indices (list[int], optional): Dataset indices matching
                ``results`` one-to-one. Defaults to
                ``list(range(len(self)))``.

        Returns:
            list[str]: Paths of the written PNG files, in iteration order.

        Note:
            ``results`` and ``indices`` are paired with ``zip``, so surplus
            entries in the longer of the two are silently ignored.
        """
        if indices is None:
            indices = list(range(len(self)))

        mmcv.mkdir_or_exist(imgfile_prefix)

        # The palette is identical for every image, so build it once instead
        # of re-filling a fresh array inside the loop.
        palette = np.array(self.PALETTE, dtype=np.uint8)

        result_files = []
        for result, idx in zip(results, indices):
            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]
            png_filename = osp.join(imgfile_prefix, f'{basename}.png')

            output = Image.fromarray(result.astype(np.uint8)).convert('P')
            output.putpalette(palette)
            output.save(png_filename)
            result_files.append(png_filename)

        return result_files

    def format_results(self, results, imgfile_prefix, indices=None):
        """Format the results into dir (for evaluation or visualization).

        Thin wrapper that delegates to :meth:`results2img`.

        Args:
            results (Iterable): Per-image predictions.
            imgfile_prefix (str): Output directory for the PNG files.
            indices (list[int], optional): Dataset indices matching
                ``results``. Defaults to all indices of the dataset.

        Returns:
            list[str]: Paths of the written PNG files.
        """
        return self.results2img(results, imgfile_prefix, indices)

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 imgfile_prefix=None):
        """Evaluate the results with the given metric.

        A single metric is wrapped in a list, then the work is handed to
        ``CustomDataset.evaluate``.

        Args:
            results: Predictions, passed through to the parent
                implementation.
            metric (str | list[str]): Metric name or list of metric names.
                Default: 'mIoU'.
            logger: Logger forwarded to the parent implementation.
                Default: None.
            imgfile_prefix (str, optional): Accepted but not used by this
                method. Default: None.

        Returns:
            Whatever ``CustomDataset.evaluate`` returns for these arguments.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        return super().evaluate(results, metrics, logger)