Skip to content

Commit ad517d4

Browse files
Yshuo-Li李尹硕
andauthored
Add down_sampling.py for generating LQ image from GT image. (open-mmlab#222)
* Add down_sampling.py for generating LQ image from GT image, which is required in LIIF. * Add '__repr__' and test_down_sampling.py. * Add docstring, rename parameter and change the function of resize. * Fine-tuning code and docstring of RandomDownSampling class. * Remove hardcode of bicubic and pillow. Co-authored-by: 李尹硕 <SENSETIME\liyinshuo@cn0014004493l.domain.sensetime.com>
1 parent 79f74a1 commit ad517d4

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed

mmedit/datasets/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .compose import Compose
66
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
77
FixedCrop, ModCrop, PairedRandomCrop)
8+
from .down_sampling import RandomDownSampling
89
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
910
ToTensor)
1011
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
@@ -25,6 +26,6 @@
2526
'MergeFgAndBg', 'CompositeFg', 'TemporalReverse', 'LoadImageFromFileList',
2627
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
2728
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
28-
'CropAroundFg', 'GetSpatialDiscountMask',
29+
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
2930
'GenerateTrimapWithDistTransform', 'TransformTrimap'
3031
]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import math
2+
3+
import numpy as np
4+
import torch
5+
from mmcv import imresize
6+
7+
from ..registry import PIPELINES
8+
9+
10+
@PIPELINES.register_module()
11+
class RandomDownSampling:
12+
"""Generate LQ image from GT (and crop), which will randomly pick a scale.
13+
14+
Args:
15+
scale_min (float): The minimum of upsampling scale, inclusive.
16+
Default: 1.0.
17+
scale_max (float): The maximum of upsampling scale, exclusive.
18+
Default: 4.0.
19+
patch_size (int): The cropped lr patch size.
20+
Default: None, means no crop.
21+
interpolation (str): Interpolation method, accepted values are
22+
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
23+
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
24+
"hamming" for 'pillow' backend.
25+
Default: "bicubic".
26+
backend (str | None): The image resize backend type. Options are `cv2`,
27+
`pillow`, `None`. If backend is None, the global imread_backend
28+
specified by ``mmcv.use_backend()`` will be used.
29+
Default: "pillow".
30+
31+
Scale will be picked in the range of [scale_min, scale_max).
32+
"""
33+
34+
def __init__(self,
35+
scale_min=1.0,
36+
scale_max=4.0,
37+
patch_size=None,
38+
interpolation='bicubic',
39+
backend='pillow'):
40+
assert scale_max >= scale_min
41+
self.scale_min = scale_min
42+
self.scale_max = scale_max
43+
self.patch_size = patch_size
44+
self.interpolation = interpolation
45+
self.backend = backend
46+
47+
def __call__(self, results):
48+
"""Call function.
49+
50+
Args:
51+
results (dict): A dict containing the necessary information and
52+
data for augmentation. 'gt' is required.
53+
54+
Returns:
55+
dict: A dict containing the processed data and information.
56+
modified 'gt', supplement 'lq' and 'scale' to keys.
57+
"""
58+
img = results['gt']
59+
scale = np.random.uniform(self.scale_min, self.scale_max)
60+
61+
if self.patch_size is None:
62+
h_lr = math.floor(img.shape[-3] / scale + 1e-9)
63+
w_lr = math.floor(img.shape[-2] / scale + 1e-9)
64+
img = img[:round(h_lr * scale), :round(w_lr * scale), :]
65+
img_down = resize_fn(img, (w_lr, h_lr), self.interpolation,
66+
self.backend)
67+
crop_lr, crop_hr = img_down, img
68+
else:
69+
w_lr = self.patch_size
70+
w_hr = round(w_lr * scale)
71+
x0 = np.random.randint(0, img.shape[-3] - w_hr)
72+
y0 = np.random.randint(0, img.shape[-2] - w_hr)
73+
crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :]
74+
crop_lr = resize_fn(crop_hr, w_lr, self.interpolation,
75+
self.backend)
76+
results['gt'] = crop_hr
77+
results['lq'] = crop_lr
78+
results['scale'] = scale
79+
80+
return results
81+
82+
def __repr__(self):
83+
repr_str = self.__class__.__name__
84+
repr_str += (f'scale_min={self.scale_min}, '
85+
f'scale_max={self.scale_max}, '
86+
f'patch_size={self.patch_size}')
87+
88+
return repr_str
89+
90+
91+
def resize_fn(img, size, interpolation='bicubic', backend='pillow'):
92+
"""Resize the given image to a given size.
93+
94+
Args:
95+
img (ndarray | torch.Tensor): The input image.
96+
size (int | tuple[int]): Target size w or (w, h).
97+
interpolation (str): Interpolation method, accepted values are
98+
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
99+
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
100+
"hamming" for 'pillow' backend.
101+
Default: "bicubic".
102+
backend (str | None): The image resize backend type. Options are `cv2`,
103+
`pillow`, `None`. If backend is None, the global imread_backend
104+
specified by ``mmcv.use_backend()`` will be used.
105+
Default: "pillow".
106+
107+
Returns:
108+
ndarray | torch.Tensor: `resized_img`, whose type is same as `img`.
109+
"""
110+
if isinstance(size, int):
111+
size = (size, size)
112+
if isinstance(img, np.ndarray):
113+
return imresize(
114+
img, size, interpolation=interpolation, backend=backend)
115+
elif isinstance(img, torch.Tensor):
116+
image = imresize(
117+
img.numpy(), size, interpolation=interpolation, backend=backend)
118+
return torch.from_numpy(image)
119+
120+
else:
121+
raise TypeError('img should got np.ndarray or torch.Tensor,'
122+
f'but got {type(img)}')

tests/test_down_sampling.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
3+
from mmedit.datasets.pipelines import RandomDownSampling
4+
5+
6+
def test_down_sampling():
7+
img1 = np.uint8(np.random.randn(480, 640, 3) * 255)
8+
inputs1 = dict(gt=img1)
9+
down_sampling1 = RandomDownSampling(
10+
scale_min=1, scale_max=4, patch_size=None)
11+
results1 = down_sampling1(inputs1)
12+
assert set(list(results1.keys())) == set(['gt', 'lq', 'scale'])
13+
assert repr(down_sampling1) == (
14+
down_sampling1.__class__.__name__ +
15+
f'scale_min={down_sampling1.scale_min}, ' +
16+
f'scale_max={down_sampling1.scale_max}, ' +
17+
f'patch_size={down_sampling1.patch_size}')
18+
19+
img2 = np.uint8(np.random.randn(480, 640, 3) * 255)
20+
inputs2 = dict(gt=img2)
21+
down_sampling2 = RandomDownSampling(
22+
scale_min=1, scale_max=4, patch_size=48)
23+
results2 = down_sampling2(inputs2)
24+
assert set(list(results2.keys())) == set(['gt', 'lq', 'scale'])
25+
assert repr(down_sampling2) == (
26+
down_sampling2.__class__.__name__ +
27+
f'scale_min={down_sampling2.scale_min}, ' +
28+
f'scale_max={down_sampling2.scale_max}, ' +
29+
f'patch_size={down_sampling2.patch_size}')

0 commit comments

Comments
 (0)