Skip to content

Commit 3e52ce9

Browse files
authored
Add random_gaussian_blur layer (#20817)
* Add random_gaussian_blur * Update description and add test cases * Correct failed test case
1 parent 734cd03 commit 3e52ce9

File tree

5 files changed

+355
-0
lines changed

5 files changed

+355
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@
178178
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
179179
RandomFlip,
180180
)
181+
from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (
182+
RandomGaussianBlur,
183+
)
181184
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
182185
RandomGrayscale,
183186
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@
178178
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
179179
RandomFlip,
180180
)
181+
from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (
182+
RandomGaussianBlur,
183+
)
181184
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
182185
RandomGrayscale,
183186
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@
122122
from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
123123
RandomFlip,
124124
)
125+
from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (
126+
RandomGaussianBlur,
127+
)
125128
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
126129
RandomGrayscale,
127130
)
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
from keras.src.random import SeedGenerator
6+
7+
8+
@keras_export("keras.layers.RandomGaussianBlur")
9+
class RandomGaussianBlur(BaseImagePreprocessingLayer):
10+
"""Applies random Gaussian blur to images for data augmentation.
11+
12+
This layer performs a Gaussian blur operation on input images with a
13+
randomly selected degree of blurring, controlled by the `factor` and
14+
`sigma` arguments.
15+
16+
Args:
17+
factor: A single float or a tuple of two floats.
18+
`factor` controls the extent to which the image hue is impacted.
19+
`factor=0.0` makes this layer perform a no-op operation,
20+
while a value of `1.0` performs the most aggressive
21+
blurring available. If a tuple is used, a `factor` is
22+
sampled between the two values for every image augmented. If a
23+
single float is used, a value between `0.0` and the passed float is
24+
sampled. Default is 1.0.
25+
kernel_size: Integer. Size of the Gaussian kernel used for blurring.
26+
Must be an odd integer. Default is 3.
27+
sigma: Float or tuple of two floats. Standard deviation of the Gaussian
28+
kernel. Controls the intensity of the blur. If a tuple is provided,
29+
a value is sampled between the two for each image. Default is 1.0.
30+
value_range: the range of values the incoming images will have.
31+
Represented as a two-number tuple written `[low, high]`. This is
32+
typically either `[0, 1]` or `[0, 255]` depending on how your
33+
preprocessing pipeline is set up.
34+
seed: Integer. Used to create a random seed.
35+
"""
36+
37+
_USE_BASE_FACTOR = False
38+
_FACTOR_BOUNDS = (0, 1)
39+
40+
def __init__(
41+
self,
42+
factor=1.0,
43+
kernel_size=3,
44+
sigma=1.0,
45+
value_range=(0, 255),
46+
data_format=None,
47+
seed=None,
48+
**kwargs,
49+
):
50+
super().__init__(data_format=data_format, **kwargs)
51+
self._set_factor(factor)
52+
self.kernel_size = self._set_kernel_size(kernel_size, "kernel_size")
53+
self.sigma = self._set_factor_by_name(sigma, "sigma")
54+
self.value_range = value_range
55+
self.seed = seed
56+
self.generator = SeedGenerator(seed)
57+
58+
def _set_kernel_size(self, factor, name):
59+
error_msg = f"{name} must be an odd number. Received: {name}={factor}"
60+
if isinstance(factor, (tuple, list)):
61+
if len(factor) != 2:
62+
error_msg = (
63+
f"The `{name}` argument should be a number "
64+
"(or a list of two numbers) "
65+
f"Received: {name}={factor}"
66+
)
67+
raise ValueError(error_msg)
68+
if (factor[0] % 2 == 0) or (factor[1] % 2 == 0):
69+
raise ValueError(error_msg)
70+
lower, upper = factor
71+
elif isinstance(factor, (int, float)):
72+
if factor % 2 == 0:
73+
raise ValueError(error_msg)
74+
lower, upper = factor, factor
75+
else:
76+
raise ValueError(error_msg)
77+
78+
return lower, upper
79+
80+
def _set_factor_by_name(self, factor, name):
81+
error_msg = (
82+
f"The `{name}` argument should be a number "
83+
"(or a list of two numbers) "
84+
"in the range "
85+
f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. "
86+
f"Received: factor={factor}"
87+
)
88+
if isinstance(factor, (tuple, list)):
89+
if len(factor) != 2:
90+
raise ValueError(error_msg)
91+
if (
92+
factor[0] > self._FACTOR_BOUNDS[1]
93+
or factor[1] < self._FACTOR_BOUNDS[0]
94+
):
95+
raise ValueError(error_msg)
96+
lower, upper = sorted(factor)
97+
elif isinstance(factor, (int, float)):
98+
if (
99+
factor < self._FACTOR_BOUNDS[0]
100+
or factor > self._FACTOR_BOUNDS[1]
101+
):
102+
raise ValueError(error_msg)
103+
factor = abs(factor)
104+
lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]
105+
else:
106+
raise ValueError(error_msg)
107+
return lower, upper
108+
109+
def create_gaussian_kernel(self, kernel_size, sigma, num_channels):
110+
def get_gaussian_kernel1d(size, sigma):
111+
x = (
112+
self.backend.numpy.arange(size, dtype=self.compute_dtype)
113+
- (size - 1) / 2
114+
)
115+
kernel1d = self.backend.numpy.exp(-0.5 * (x / sigma) ** 2)
116+
return kernel1d / self.backend.numpy.sum(kernel1d)
117+
118+
def get_gaussian_kernel2d(size, sigma):
119+
kernel1d_x = get_gaussian_kernel1d(size[0], sigma[0])
120+
kernel1d_y = get_gaussian_kernel1d(size[1], sigma[1])
121+
return self.backend.numpy.tensordot(kernel1d_y, kernel1d_x, axes=0)
122+
123+
kernel = get_gaussian_kernel2d(kernel_size, sigma)
124+
125+
kernel = self.backend.numpy.reshape(
126+
kernel, (kernel_size[0], kernel_size[1], 1, 1)
127+
)
128+
kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1])
129+
130+
kernel = self.backend.cast(kernel, self.compute_dtype)
131+
132+
return kernel
133+
134+
def get_random_transformation(self, data, training=True, seed=None):
135+
if not training:
136+
return None
137+
138+
if isinstance(data, dict):
139+
images = data["images"]
140+
else:
141+
images = data
142+
143+
images_shape = self.backend.shape(images)
144+
rank = len(images_shape)
145+
if rank == 3:
146+
batch_size = 1
147+
elif rank == 4:
148+
batch_size = images_shape[0]
149+
else:
150+
raise ValueError(
151+
"Expected the input image to be rank 3 or 4. Received "
152+
f"inputs.shape={images_shape}"
153+
)
154+
155+
seed = seed or self._get_seed_generator(self.backend._backend)
156+
157+
blur_probability = self.backend.random.uniform(
158+
shape=(batch_size,),
159+
minval=self.factor[0],
160+
maxval=self.factor[1],
161+
seed=seed,
162+
)
163+
164+
random_threshold = self.backend.random.uniform(
165+
shape=(batch_size,),
166+
minval=0.0,
167+
maxval=1.0,
168+
seed=seed,
169+
)
170+
should_apply_blur = random_threshold < blur_probability
171+
172+
blur_factor = (
173+
self.backend.random.uniform(
174+
shape=(2,),
175+
minval=self.sigma[0],
176+
maxval=self.sigma[1],
177+
seed=seed,
178+
dtype=self.compute_dtype,
179+
)
180+
+ 1e-6
181+
)
182+
183+
return {
184+
"should_apply_blur": should_apply_blur,
185+
"blur_factor": blur_factor,
186+
}
187+
188+
def transform_images(self, images, transformation=None, training=True):
189+
images = self.backend.cast(images, self.compute_dtype)
190+
if training and transformation is not None:
191+
if self.data_format == "channels_first":
192+
images = self.backend.numpy.swapaxes(images, -3, -1)
193+
194+
blur_factor = transformation["blur_factor"]
195+
should_apply_blur = transformation["should_apply_blur"]
196+
197+
kernel = self.create_gaussian_kernel(
198+
self.kernel_size,
199+
blur_factor,
200+
self.backend.shape(images)[-1],
201+
)
202+
203+
blur_images = self.backend.nn.depthwise_conv(
204+
images,
205+
kernel,
206+
strides=1,
207+
padding="same",
208+
data_format="channels_last",
209+
)
210+
211+
images = self.backend.numpy.where(
212+
should_apply_blur[:, None, None, None],
213+
blur_images,
214+
images,
215+
)
216+
217+
images = self.backend.numpy.clip(
218+
images, self.value_range[0], self.value_range[1]
219+
)
220+
221+
if self.data_format == "channels_first":
222+
images = self.backend.numpy.swapaxes(images, -3, -1)
223+
224+
images = self.backend.cast(images, dtype=self.compute_dtype)
225+
226+
return images
227+
228+
def transform_labels(self, labels, transformation, training=True):
229+
return labels
230+
231+
def transform_segmentation_masks(
232+
self, segmentation_masks, transformation, training=True
233+
):
234+
return segmentation_masks
235+
236+
def transform_bounding_boxes(
237+
self, bounding_boxes, transformation, training=True
238+
):
239+
return bounding_boxes
240+
241+
def compute_output_shape(self, input_shape):
242+
return input_shape
243+
244+
def get_config(self):
245+
config = super().get_config()
246+
config.update(
247+
{
248+
"factor": self.factor,
249+
"kernel_size": self.kernel_size,
250+
"sigma": self.sigma,
251+
"value_range": self.value_range,
252+
"seed": self.seed,
253+
}
254+
)
255+
return config
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
from keras.src import backend
6+
from keras.src import layers
7+
from keras.src import testing
8+
from keras.src.backend import convert_to_tensor
9+
10+
11+
class RandomGaussianBlurTest(testing.TestCase):
12+
@pytest.mark.requires_trainable_backend
13+
def test_layer(self):
14+
self.run_layer_test(
15+
layers.RandomGaussianBlur,
16+
init_kwargs={
17+
"factor": 1.0,
18+
"kernel_size": 3,
19+
"sigma": 0,
20+
"value_range": (0, 255),
21+
"seed": 1,
22+
},
23+
input_shape=(8, 3, 4, 3),
24+
supports_masking=False,
25+
expected_output_shape=(8, 3, 4, 3),
26+
)
27+
28+
def test_random_erasing_inference(self):
29+
seed = 3481
30+
layer = layers.RandomGaussianBlur()
31+
32+
np.random.seed(seed)
33+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
34+
output = layer(inputs, training=False)
35+
self.assertAllClose(inputs, output)
36+
37+
def test_random_erasing_no_op(self):
38+
seed = 3481
39+
layer = layers.RandomGaussianBlur(factor=0)
40+
41+
np.random.seed(seed)
42+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
43+
output = layer(inputs)
44+
self.assertAllClose(inputs, output)
45+
46+
def test_random_erasing_basic(self):
47+
data_format = backend.config.image_data_format()
48+
if data_format == "channels_last":
49+
inputs = np.ones((1, 2, 2, 3))
50+
expected_output = np.asarray(
51+
[
52+
[
53+
[[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]],
54+
[[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]],
55+
]
56+
]
57+
)
58+
59+
else:
60+
inputs = np.ones((1, 3, 2, 2))
61+
expected_output = np.asarray(
62+
[
63+
[
64+
[[0.7273, 0.7273], [0.7273, 0.7273]],
65+
[[0.7273, 0.7273], [0.7273, 0.7273]],
66+
[[0.7273, 0.7273], [0.7273, 0.7273]],
67+
]
68+
]
69+
)
70+
71+
layer = layers.RandomGaussianBlur(data_format=data_format)
72+
73+
transformation = {
74+
"blur_factor": convert_to_tensor([0.3732, 0.8654]),
75+
"should_apply_blur": convert_to_tensor([True]),
76+
}
77+
output = layer.transform_images(inputs, transformation)
78+
79+
self.assertAllClose(expected_output, output, atol=1e-4, rtol=1e-4)
80+
81+
def test_tf_data_compatibility(self):
82+
data_format = backend.config.image_data_format()
83+
if data_format == "channels_last":
84+
input_data = np.random.random((2, 8, 8, 3))
85+
else:
86+
input_data = np.random.random((2, 3, 8, 8))
87+
layer = layers.RandomGaussianBlur(data_format=data_format)
88+
89+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
90+
for output in ds.take(1):
91+
output.numpy()

0 commit comments

Comments
 (0)