From 3d1877511333d8db34103cc33ddc85a9fe214288 Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU Date: Sat, 7 Nov 2020 09:53:12 -0800 Subject: [PATCH] [Feature] Add RandomRotate transform (#215) * add RandomRotate for transforms * change rotation function to mmcv.imrotate * refactor * add unittest * fixed test * fixed docstring * fixed test * add more test * fixed repr * rename to prob * fixed unittest Co-authored-by: hkzhang95 --- configs/_base_/datasets/ade20k.py | 2 +- configs/_base_/datasets/cityscapes.py | 2 +- configs/_base_/datasets/cityscapes_769x769.py | 2 +- configs/_base_/datasets/pascal_context.py | 2 +- configs/_base_/datasets/pascal_voc12.py | 2 +- mmseg/datasets/pipelines/transforms.py | 99 +++++++++++++++++-- tests/test_data/test_dataset.py | 2 +- tests/test_data/test_transform.py | 50 +++++++++- 8 files changed, 143 insertions(+), 18 deletions(-) diff --git a/configs/_base_/datasets/ade20k.py b/configs/_base_/datasets/ade20k.py index a1d9bab..efc8b4b 100644 --- a/configs/_base_/datasets/ade20k.py +++ b/configs/_base_/datasets/ade20k.py @@ -9,7 +9,7 @@ train_pipeline = [ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/configs/_base_/datasets/cityscapes.py b/configs/_base_/datasets/cityscapes.py index 21cf5c3..f21867c 100644 --- a/configs/_base_/datasets/cityscapes.py +++ b/configs/_base_/datasets/cityscapes.py @@ -9,7 +9,7 @@ train_pipeline = [ dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/configs/_base_/datasets/cityscapes_769x769.py b/configs/_base_/datasets/cityscapes_769x769.py index a5bcff3..336c7b2 100644 --- a/configs/_base_/datasets/cityscapes_769x769.py +++ b/configs/_base_/datasets/cityscapes_769x769.py @@ -7,7 +7,7 @@ train_pipeline = [ dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/configs/_base_/datasets/pascal_context.py b/configs/_base_/datasets/pascal_context.py index a00e474..ff65bad 100644 --- a/configs/_base_/datasets/pascal_context.py +++ b/configs/_base_/datasets/pascal_context.py @@ -12,7 +12,7 @@ train_pipeline = [ dict(type='LoadAnnotations'), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/configs/_base_/datasets/pascal_voc12.py b/configs/_base_/datasets/pascal_voc12.py index 6a367c7..ba1d42d 100644 --- a/configs/_base_/datasets/pascal_voc12.py +++ b/configs/_base_/datasets/pascal_voc12.py @@ -9,7 +9,7 @@ train_pipeline = [ dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 2b314a8..592854a 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -1,5 +1,6 @@ import mmcv import numpy as np +from mmcv.utils import deprecated_api_warning from numpy import random from ..builder import PIPELINES @@ -232,16 +233,17 @@ class RandomFlip(object): method. Args: - flip_ratio (float, optional): The flipping probability. Default: None. + prob (float, optional): The flipping probability. Default: None. direction(str, optional): The flipping direction. Options are 'horizontal' and 'vertical'. Default: 'horizontal'. """ - def __init__(self, flip_ratio=None, direction='horizontal'): - self.flip_ratio = flip_ratio + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob self.direction = direction - if flip_ratio is not None: - assert flip_ratio >= 0 and flip_ratio <= 1 + if prob is not None: + assert prob >= 0 and prob <= 1 assert direction in ['horizontal', 'vertical'] def __call__(self, results): @@ -257,7 +259,7 @@ class RandomFlip(object): """ if 'flip' not in results: - flip = True if np.random.rand() < self.flip_ratio else False + flip = True if np.random.rand() < self.prob else False results['flip'] = flip if 'flip_direction' not in results: results['flip_direction'] = self.direction @@ -274,7 +276,7 @@ class RandomFlip(object): return results def __repr__(self): - return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + return self.__class__.__name__ + f'(prob={self.prob})' @PIPELINES.register_module() @@ -463,6 +465,89 @@ class RandomCrop(object): return self.__class__.__name__ + f'(crop_size={self.crop_size})' +@PIPELINES.register_module() +class RandomRotate(object): + """Rotate the image & seg. + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + def __call__(self, results): + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate = True if np.random.rand() < self.prob else False + degree = np.random.uniform(min(*self.degree), max(*self.degree)) + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + @PIPELINES.register_module() class SegRescale(object): """Rescale semantic segmentation maps. diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index d7e44f5..e933c20 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -69,7 +69,7 @@ def test_custom_dataset(): dict(type='LoadAnnotations'), dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index 7a1ca0d..ff06375 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -94,18 +94,17 @@ def test_resize(): def test_flip(): - # test assertion for invalid flip_ratio + # test assertion for invalid prob with pytest.raises(AssertionError): - transform = dict(type='RandomFlip', flip_ratio=1.5) + transform = dict(type='RandomFlip', prob=1.5) build_from_cfg(transform, PIPELINES) # test assertion for invalid direction with pytest.raises(AssertionError): - transform = dict( - type='RandomFlip', flip_ratio=1, direction='horizonta') + transform = dict(type='RandomFlip', prob=1, direction='horizonta') build_from_cfg(transform, PIPELINES) - transform = dict(type='RandomFlip', flip_ratio=1) + transform = dict(type='RandomFlip', prob=1) flip_module = build_from_cfg(transform, PIPELINES) results = dict() @@ -197,6 +196,47 @@ def test_pad(): assert img_shape[1] % 32 == 0 +def test_rotate(): + # test assertion degree should be tuple[float] or float + with pytest.raises(AssertionError): + transform = dict(type='RandomRotate', prob=0.5, degree=-10) + build_from_cfg(transform, PIPELINES) + # test assertion degree should be tuple[float] or float + with pytest.raises(AssertionError): + transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.)) + build_from_cfg(transform, PIPELINES) + + transform = dict(type='RandomRotate', degree=10., prob=1.) + transform = build_from_cfg(transform, PIPELINES) + + assert str(transform) == f'RandomRotate(' \ + f'prob={1.}, ' \ + f'degree=({-10.}, {10.}), ' \ + f'pad_val={0}, ' \ + f'seg_pad_val={255}, ' \ + f'center={None}, ' \ + f'auto_bound={False})' + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + h, w, _ = img.shape + seg = np.array( + Image.open(osp.join(osp.dirname(__file__), '../data/seg.png'))) + results['img'] = img + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + assert results['img'].shape[:2] == (h, w) + assert results['gt_semantic_seg'].shape[:2] == (h, w) + + def test_normalize(): img_norm_cfg = dict( mean=[123.675, 116.28, 103.53],