
Commit 3de2270

Add random_erasing layer (#20798)
* Add initial random_erasing
* Update random_erasing logic
* Update description and add test case
* fix value range bug
* add seed for random fill_value
1 parent d7d2e43 commit 3de2270

File tree

5 files changed: +424 -0 lines changed


keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_crop import (
     RandomCrop,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_erasing import (
+    RandomErasing,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_crop import (
     RandomCrop,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_erasing import (
+    RandomErasing,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_crop import (
     RandomCrop,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_erasing import (
+    RandomErasing,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
keras/src/layers/preprocessing/image_preprocessing/random_erasing.py

Lines changed: 324 additions & 0 deletions

@@ -0,0 +1,324 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.RandomErasing")
class RandomErasing(BaseImagePreprocessingLayer):
    """Random Erasing data augmentation technique.

    Random Erasing is a data augmentation method where random patches of
    an image are erased (replaced by a constant value or noise)
    during training to improve generalization.

    Args:
        factor: A single float or a tuple of two floats.
            `factor` controls the probability of erasing a patch from each
            image. `factor=0.0` makes this layer perform a no-op operation,
            while a value of `1.0` performs the most aggressive
            erasing available. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented. If a
            single float is used, a value between `0.0` and the passed float
            is sampled. Default is 1.0.
        scale: A tuple of two floats representing the range of the area
            fraction of the image covered by the erased patch. This controls
            how large the erased region can be. Default is (0.02, 0.33).
        fill_value: A value to fill the erased region with. This can be set
            to a constant value or `None` to sample a random value
            from a normal distribution. Default is `None`.
        value_range: the range of values the incoming images will have.
            Represented as a two-number tuple written `[low, high]`. This is
            typically either `[0, 1]` or `[0, 255]` depending on how your
            preprocessing pipeline is set up.
        seed: Integer. Used to create a random seed.

    References:
        - [Random Erasing paper](https://arxiv.org/abs/1708.04896).
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (0, 1)

    def __init__(
        self,
        factor=1.0,
        scale=(0.02, 0.33),
        fill_value=None,
        value_range=(0, 255),
        seed=None,
        data_format=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)
        self._set_factor(factor)
        self.scale = self._set_factor_by_name(scale, "scale")
        self.fill_value = fill_value
        self.value_range = value_range
        self.seed = seed
        self.generator = SeedGenerator(seed)

        if self.data_format == "channels_first":
            self.height_axis = -2
            self.width_axis = -1
            self.channel_axis = -3
        else:
            self.height_axis = -3
            self.width_axis = -2
            self.channel_axis = -1

    def _set_factor_by_name(self, factor, name):
        error_msg = (
            f"The `{name}` argument should be a number "
            "(or a list of two numbers) "
            "in the range "
            f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. "
            f"Received: factor={factor}"
        )
        if isinstance(factor, (tuple, list)):
            if len(factor) != 2:
                raise ValueError(error_msg)
            if (
                factor[0] > self._FACTOR_BOUNDS[1]
                or factor[1] < self._FACTOR_BOUNDS[0]
            ):
                raise ValueError(error_msg)
            lower, upper = sorted(factor)
        elif isinstance(factor, (int, float)):
            if (
                factor < self._FACTOR_BOUNDS[0]
                or factor > self._FACTOR_BOUNDS[1]
            ):
                raise ValueError(error_msg)
            factor = abs(factor)
            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]
        else:
            raise ValueError(error_msg)
        return lower, upper

    def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed):
        crop_length = self.backend.cast(
            crop_ratio * image_length, dtype=self.compute_dtype
        )

        start_pos = self.backend.random.uniform(
            shape=[batch_size],
            minval=0,
            maxval=1,
            dtype=self.compute_dtype,
            seed=seed,
        ) * (image_length - crop_length)

        end_pos = start_pos + crop_length

        return start_pos, end_pos

    def _generate_batch_mask(self, images_shape, box_corners):
        def _generate_grid_xy(image_height, image_width):
            grid_y, grid_x = self.backend.numpy.meshgrid(
                self.backend.numpy.arange(
                    image_height, dtype=self.compute_dtype
                ),
                self.backend.numpy.arange(
                    image_width, dtype=self.compute_dtype
                ),
                indexing="ij",
            )
            if self.data_format == "channels_last":
                grid_y = self.backend.cast(
                    grid_y[None, :, :, None], dtype=self.compute_dtype
                )
                grid_x = self.backend.cast(
                    grid_x[None, :, :, None], dtype=self.compute_dtype
                )
            else:
                grid_y = self.backend.cast(
                    grid_y[None, None, :, :], dtype=self.compute_dtype
                )
                grid_x = self.backend.cast(
                    grid_x[None, None, :, :], dtype=self.compute_dtype
                )
            return grid_x, grid_y

        image_height, image_width = (
            images_shape[self.height_axis],
            images_shape[self.width_axis],
        )
        grid_x, grid_y = _generate_grid_xy(image_height, image_width)

        x0, x1, y0, y1 = box_corners

        x0 = x0[:, None, None, None]
        y0 = y0[:, None, None, None]
        x1 = x1[:, None, None, None]
        y1 = y1[:, None, None, None]

        batch_masks = (
            (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1)
        )
        batch_masks = self.backend.numpy.repeat(
            batch_masks, images_shape[self.channel_axis], axis=self.channel_axis
        )

        return batch_masks

    def _get_fill_value(self, images, images_shape, seed):
        fill_value = self.fill_value
        if fill_value is None:
            fill_value = (
                self.backend.random.normal(
                    images_shape,
                    dtype=self.compute_dtype,
                    seed=seed,
                )
                * self.value_range[1]
            )
        else:
            error_msg = (
                "The `fill_value` argument should be a number "
                "(or a list of three numbers) "
            )
            if isinstance(fill_value, (tuple, list)):
                if len(fill_value) != 3:
                    raise ValueError(error_msg)
                fill_value = self.backend.numpy.full_like(
                    images, fill_value, dtype=self.compute_dtype
                )
            elif isinstance(fill_value, (int, float)):
                fill_value = (
                    self.backend.numpy.ones(
                        images_shape, dtype=self.compute_dtype
                    )
                    * fill_value
                )
            else:
                raise ValueError(error_msg)
            fill_value = self.backend.numpy.clip(
                fill_value, self.value_range[0], self.value_range[1]
            )
        return fill_value

    def get_random_transformation(self, data, training=True, seed=None):
        if not training:
            return None

        if isinstance(data, dict):
            images = data["images"]
        else:
            images = data

        images_shape = self.backend.shape(images)
        rank = len(images_shape)
        if rank == 3:
            batch_size = 1
        elif rank == 4:
            batch_size = images_shape[0]
        else:
            raise ValueError(
                "Expected the input image to be rank 3 or 4. Received "
                f"inputs.shape={images_shape}"
            )

        image_height = images_shape[self.height_axis]
        image_width = images_shape[self.width_axis]

        seed = seed or self._get_seed_generator(self.backend._backend)

        mix_weight = self.backend.random.uniform(
            shape=(batch_size, 2),
            minval=self.scale[0],
            maxval=self.scale[1],
            dtype=self.compute_dtype,
            seed=seed,
        )

        mix_weight = self.backend.numpy.sqrt(mix_weight)

        x0, x1 = self._compute_crop_bounds(
            batch_size, image_width, mix_weight[:, 0], seed
        )
        y0, y1 = self._compute_crop_bounds(
            batch_size, image_height, mix_weight[:, 1], seed
        )

        batch_masks = self._generate_batch_mask(
            images_shape,
            (x0, x1, y0, y1),
        )

        erase_probability = self.backend.random.uniform(
            shape=(batch_size,),
            minval=self.factor[0],
            maxval=self.factor[1],
            seed=seed,
        )

        random_threshold = self.backend.random.uniform(
            shape=(batch_size,),
            minval=0.0,
            maxval=1.0,
            seed=seed,
        )
        apply_erasing = random_threshold < erase_probability

        fill_value = self._get_fill_value(images, images_shape, seed)

        return {
            "apply_erasing": apply_erasing,
            "batch_masks": batch_masks,
            "fill_value": fill_value,
        }

    def transform_images(self, images, transformation=None, training=True):
        if training:
            images = self.backend.cast(images, self.compute_dtype)
            batch_masks = transformation["batch_masks"]
            apply_erasing = transformation["apply_erasing"]
            fill_value = transformation["fill_value"]

            erased_images = self.backend.numpy.where(
                batch_masks,
                fill_value,
                images,
            )

            images = self.backend.numpy.where(
                apply_erasing[:, None, None, None],
                erased_images,
                images,
            )

            images = self.backend.cast(images, self.compute_dtype)
        return images

    def transform_labels(self, labels, transformation, training=True):
        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        return bounding_boxes

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        return segmentation_masks

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "factor": self.factor,
            "scale": self.scale,
            "fill_value": self.fill_value,
            "value_range": self.value_range,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
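
For context, here is a minimal usage sketch of the new layer, based on the constructor arguments shown in this diff. The sample images, shapes, and seed value are illustrative assumptions, not part of the commit.

import numpy as np
import keras

# Hypothetical example: erase a random patch from each image in a batch.
layer = keras.layers.RandomErasing(
    factor=1.0,            # always attempt to erase a patch
    scale=(0.02, 0.33),    # area fraction range of the erased patch
    fill_value=None,       # None -> fill the patch with random noise
    value_range=(0, 255),  # images are expected in [0, 255]
    seed=42,
)

images = np.random.randint(0, 256, size=(4, 32, 32, 3)).astype("float32")
augmented = layer(images, training=True)  # erasing is only applied in training
print(augmented.shape)  # (4, 32, 32, 3)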
