From 796d5edebe12090721554917902d34df0c060d73 Mon Sep 17 00:00:00 2001 From: Shouping Shan Date: Fri, 8 Oct 2021 01:06:18 +0800 Subject: [PATCH] [Fix] Fix bug when loading class name form file in custom dataset (#923) * [Fix] #916 expection string type classes * add unittests for string path classes * fix double quote string in test_dataset.py * move the import to the top of the file * fix isort lint error fix isort lint error when move the import to the top of the file --- mmseg/datasets/custom.py | 4 ++-- tests/test_data/test_dataset.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 23b347d..872b2b8 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -319,7 +319,7 @@ class CustomDataset(Dataset): raise ValueError(f'Unsupported type {type(classes)} of classes.') if self.CLASSES: - if not set(classes).issubset(self.CLASSES): + if not set(class_names).issubset(self.CLASSES): raise ValueError('classes is not a subset of CLASSES.') # dictionary, its keys are the old label ids and its values @@ -330,7 +330,7 @@ class CustomDataset(Dataset): if c not in class_names: self.label_map[i] = -1 else: - self.label_map[i] = classes.index(c) + self.label_map[i] = class_names.index(c) palette = self.get_palette_for_custom_classes(class_names, palette) diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index f1ce7bb..6524419 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os import os.path as osp import shutil +import tempfile from typing import Generator from unittest.mock import MagicMock, patch @@ -26,6 +28,37 @@ def test_classes(): get_classes('unsupported') +def test_classes_file_path(): + tmp_file = tempfile.NamedTemporaryFile() + classes_path = f'{tmp_file.name}.txt' + train_pipeline = [dict(type='LoadImageFromFile')] + kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path) + + # classes.txt with full categories + categories = get_classes('cityscapes') + with open(classes_path, 'w') as f: + f.write('\n'.join(categories)) + assert list(CityscapesDataset(**kwargs).CLASSES) == categories + + # classes.txt with sub categories + categories = ['road', 'sidewalk', 'building'] + with open(classes_path, 'w') as f: + f.write('\n'.join(categories)) + assert list(CityscapesDataset(**kwargs).CLASSES) == categories + + # classes.txt with unknown categories + categories = ['road', 'sidewalk', 'unknown'] + with open(classes_path, 'w') as f: + f.write('\n'.join(categories)) + + with pytest.raises(ValueError): + CityscapesDataset(**kwargs) + + tmp_file.close() + os.remove(classes_path) + assert not osp.exists(classes_path) + + def test_palette(): assert CityscapesDataset.PALETTE == get_palette('cityscapes') assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(