[Feature]add CLAHE transform (#229)
* add CLAHE transform * fix syntax error * fix syntax error * restore * add a test * modify cv2 to mmcv * add docstring * modify * restore * fix mmcv.clahe error * change mmcv version to 1.3.0 * fix bugs * add all data transformers to __init__ * fix __init__ * fix test_transform
This commit is contained in:
parent
4dc809adf2
commit
0066ce884f
@ -3,14 +3,14 @@ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
from .loading import LoadAnnotations, LoadImageFromFile
|
||||
from .test_time_aug import MultiScaleFlipAug
|
||||
from .transforms import (Normalize, Pad, PhotoMetricDistortion, RandomCrop,
|
||||
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
|
||||
SegRescale)
|
||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
||||
PhotoMetricDistortion, RandomCrop, RandomFlip,
|
||||
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
||||
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||
'Rerange', 'RGB2Gray'
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
|
||||
]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import deprecated_api_warning
|
||||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||
from numpy import random
|
||||
|
||||
from ..builder import PIPELINES
|
||||
@ -415,7 +415,6 @@ class Rerange(object):
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Reranged results.
|
||||
"""
|
||||
@ -439,6 +438,51 @@ class Rerange(object):
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class CLAHE(object):
|
||||
"""Use CLAHE method to process the image.
|
||||
|
||||
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
|
||||
Graphics Gems, 1994:474-485.` for more information.
|
||||
|
||||
Args:
|
||||
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
|
||||
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
|
||||
Input image will be divided into equally sized rectangular tiles.
|
||||
It defines the number of tiles in row and column. Default: (8, 8).
|
||||
"""
|
||||
|
||||
def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
|
||||
assert isinstance(clip_limit, (float, int))
|
||||
self.clip_limit = clip_limit
|
||||
assert is_tuple_of(tile_grid_size, int)
|
||||
assert len(tile_grid_size) == 2
|
||||
self.tile_grid_size = tile_grid_size
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to Use CLAHE method process images.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Processed results.
|
||||
"""
|
||||
|
||||
for i in range(results['img'].shape[2]):
|
||||
results['img'][:, :, i] = mmcv.clahe(
|
||||
np.array(results['img'][:, :, i], dtype=np.uint8),
|
||||
self.clip_limit, self.tile_grid_size)
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(clip_limit={self.clip_limit}, '\
|
||||
f'tile_grid_size={self.tile_grid_size})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomCrop(object):
|
||||
"""Random crop the image & seg.
|
||||
|
||||
@ -409,6 +409,46 @@ def test_rerange():
|
||||
assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
|
||||
|
||||
|
||||
def test_CLAHE():
|
||||
# test assertion if clip_limit is None
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', clip_limit=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if tile_grid_size is illegal
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if tile_grid_size is illegal
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
transform = dict(type='CLAHE', clip_limit=2)
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
results = transform(results)
|
||||
|
||||
converted_img = np.empty(original_img.shape)
|
||||
for i in range(original_img.shape[2]):
|
||||
converted_img[:, :, i] = mmcv.clahe(
|
||||
np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))
|
||||
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
|
||||
|
||||
|
||||
def test_seg_rescale():
|
||||
results = dict()
|
||||
seg = np.array(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user