[Feature]add CLAHE transform (#229)

* add CLAHE transform

* fix syntax error

* fix syntax error

* restore

* add a test

* modify cv2 to mmcv

* add docstring

* modify

* restore

* fix mmcv.clahe error

* change mmcv version to 1.3.0

* fix bugs

* add all data transformers to __init__

* fix __init__

* fix test_transform
This commit is contained in:
yamengxi 2020-12-02 13:08:16 +08:00 committed by GitHub
parent 4dc809adf2
commit 0066ce884f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 6 deletions

View File

@ -3,14 +3,14 @@ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomFlip,
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'Rerange', 'RGB2Gray'
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
]

View File

@ -1,6 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random
from ..builder import PIPELINES
@ -415,7 +415,6 @@ class Rerange(object):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reranged results.
"""
@ -439,6 +438,51 @@ class Rerange(object):
return repr_str
@PIPELINES.register_module()
class CLAHE(object):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
"""
def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
assert isinstance(clip_limit, (float, int))
self.clip_limit = clip_limit
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size
def __call__(self, results):
"""Call function to Use CLAHE method process images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Processed results.
"""
for i in range(results['img'].shape[2]):
results['img'][:, :, i] = mmcv.clahe(
np.array(results['img'][:, :, i], dtype=np.uint8),
self.clip_limit, self.tile_grid_size)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(clip_limit={self.clip_limit}, '\
f'tile_grid_size={self.tile_grid_size})'
return repr_str
@PIPELINES.register_module()
class RandomCrop(object):
"""Random crop the image & seg.

View File

@ -409,6 +409,46 @@ def test_rerange():
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
def test_CLAHE():
# test assertion if clip_limit is None
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', clip_limit=None)
build_from_cfg(transform, PIPELINES)
# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
build_from_cfg(transform, PIPELINES)
# test assertion if tile_grid_size is illegal
with pytest.raises(AssertionError):
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
build_from_cfg(transform, PIPELINES)
transform = dict(type='CLAHE', clip_limit=2)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
results = transform(results)
converted_img = np.empty(original_img.shape)
for i in range(original_img.shape[2]):
converted_img[:, :, i] = mmcv.clahe(
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
def test_seg_rescale():
results = dict()
seg = np.array(