Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
Expand Down
1 change: 1 addition & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
Expand Down
1 change: 1 addition & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
Expand Down
226 changes: 226 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/cut_mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.CutMix")
class CutMix(BaseImagePreprocessingLayer):
    """CutMix data augmentation technique.

    CutMix is a data augmentation method where patches are cut and pasted
    between two images in the dataset, while the labels are also mixed
    proportionally to the area of the patches.

    Args:
        factor: A single float or a tuple of two floats between 0 and 1.
            If a tuple of numbers is passed, a `factor` is sampled
            between the two values.
            If a single float is passed, a value between 0 and the passed
            float is sampled. These values define the range from which the
            mixing weight is sampled. A higher factor increases the
            variability in patch sizes, leading to more diverse and larger
            mixed patches. Defaults to 1.
        seed: Integer. Used to create a random seed.
        data_format: String. `"channels_first"` selects the
            `(..., channels, height, width)` axis layout; any other
            resolved value is treated as `(..., height, width, channels)`.
            The argument is forwarded to the base layer, which resolves
            `None`.

    References:
       - [CutMix paper](https://arxiv.org/abs/1905.04899).
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (0, 1)

    def __init__(self, factor=1.0, seed=None, data_format=None, **kwargs):
        super().__init__(data_format=data_format, **kwargs)
        self._set_factor(factor)
        self.seed = seed
        self.generator = SeedGenerator(seed)

        # Negative axes let the same indices address batched inputs.
        if self.data_format == "channels_first":
            self.height_axis = -2
            self.width_axis = -1
            self.channel_axis = -3
        else:
            self.height_axis = -3
            self.width_axis = -2
            self.channel_axis = -1

    def get_random_transformation(self, data, training=True, seed=None):
        """Sample the per-batch patch masks and the pairing of images.

        Args:
            data: Either the images themselves or a dict holding them under
                the `"images"` key.
            training: Whether the layer runs in training mode.
            seed: Optional seed; when falsy, the layer's seed generator for
                the current backend is used.

        Returns:
            `None` when not training or when the images are rank 3
            (a single unbatched image has no partner to mix with);
            otherwise a dict with `"permutation_order"`, `"batch_masks"`
            and `"mix_weight"`.
        """
        if not training:
            return None

        if isinstance(data, dict):
            images = data["images"]
        else:
            images = data

        images_shape = self.backend.shape(images)
        if len(images_shape) == 3:
            return None

        batch_size = images_shape[0]
        image_height = images_shape[self.height_axis]
        image_width = images_shape[self.width_axis]

        seed = seed or self._get_seed_generator(self.backend._backend)

        mix_weight = self._generate_mix_weight(batch_size, seed)
        # The patch covers a `1 - mix_weight` fraction of the image area, so
        # each side is scaled by the square root of that fraction.
        ratio = self.backend.numpy.sqrt(1.0 - mix_weight)

        x0, x1 = self._compute_crop_bounds(batch_size, image_width, ratio, seed)
        y0, y1 = self._compute_crop_bounds(
            batch_size, image_height, ratio, seed
        )

        # The weight is recomputed from the sampled box so that it comes back
        # shaped to broadcast against the batch.
        batch_masks, mix_weight = self._generate_batch_mask(
            images_shape,
            (x0, x1, y0, y1),
        )

        permutation_order = self.backend.random.shuffle(
            self.backend.numpy.arange(0, batch_size, dtype="int32"),
            seed=seed,
        )

        return {
            "permutation_order": permutation_order,
            "batch_masks": batch_masks,
            "mix_weight": mix_weight,
        }

    def _generate_batch_mask(self, images_shape, box_corners):
        """Build the boolean patch masks and the matching label weights.

        Args:
            images_shape: Shape of the batched images.
            box_corners: Tuple `(x0, x1, y0, y1)` of per-sample patch bounds
                in pixel coordinates, each of shape `(batch,)`.

        Returns:
            A tuple `(batch_masks, mix_weight)`. `batch_masks` is `True`
            inside each sample's patch and is repeated along the channel
            axis; `mix_weight` has shape `(batch, 1, 1, 1)` and equals one
            minus the patch area divided by the image area.
        """

        def _generate_grid_xy(image_height, image_width):
            # Pixel-coordinate grids with singleton batch and channel axes.
            grid_y, grid_x = self.backend.numpy.meshgrid(
                self.backend.numpy.arange(
                    image_height, dtype=self.compute_dtype
                ),
                self.backend.numpy.arange(
                    image_width, dtype=self.compute_dtype
                ),
                indexing="ij",
            )
            if self.data_format == "channels_last":
                grid_y = self.backend.cast(
                    grid_y[None, :, :, None], dtype=self.compute_dtype
                )
                grid_x = self.backend.cast(
                    grid_x[None, :, :, None], dtype=self.compute_dtype
                )
            else:
                grid_y = self.backend.cast(
                    grid_y[None, None, :, :], dtype=self.compute_dtype
                )
                grid_x = self.backend.cast(
                    grid_x[None, None, :, :], dtype=self.compute_dtype
                )
            return grid_x, grid_y

        image_height, image_width = (
            images_shape[self.height_axis],
            images_shape[self.width_axis],
        )
        grid_x, grid_y = _generate_grid_xy(image_height, image_width)

        x0, x1, y0, y1 = box_corners

        # Expand the per-sample bounds so they broadcast against the grids.
        x0 = x0[:, None, None, None]
        y0 = y0[:, None, None, None]
        x1 = x1[:, None, None, None]
        y1 = y1[:, None, None, None]

        batch_masks = (
            (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1)
        )
        batch_masks = self.backend.numpy.repeat(
            batch_masks, images_shape[self.channel_axis], axis=self.channel_axis
        )
        mix_weight = 1.0 - (x1 - x0) * (y1 - y0) / (image_width * image_height)
        return batch_masks, mix_weight

    def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed):
        """Sample a random `[start, end)` interval along one image axis.

        Args:
            batch_size: Number of intervals to sample.
            image_length: Size of the image along the axis.
            crop_ratio: Per-sample fraction of `image_length` to cover.
            seed: Seed passed to the backend random op.

        Returns:
            A tuple `(start_pos, end_pos)` of float positions. The start is
            drawn uniformly from `[0, image_length - crop_length)`, so the
            interval stays inside the image.
        """
        crop_length = self.backend.cast(
            crop_ratio * image_length, dtype=self.compute_dtype
        )

        start_pos = self.backend.random.uniform(
            shape=[batch_size],
            minval=0,
            maxval=1,
            dtype=self.compute_dtype,
            seed=seed,
        ) * (image_length - crop_length)

        end_pos = start_pos + crop_length

        return start_pos, end_pos

    def _generate_mix_weight(self, batch_size, seed):
        """Draw one mixing weight per sample from `Beta(alpha, alpha)`.

        `alpha` is a single scalar sampled uniformly from the configured
        `factor` range and shared by the whole batch.
        """
        # NOTE(review): the 1e-6 offset presumably keeps `alpha` strictly
        # positive when the factor range starts at 0 -- confirm.
        alpha = (
            self.backend.random.uniform(
                shape=(),
                minval=self.factor[0],
                maxval=self.factor[1],
                dtype=self.compute_dtype,
                seed=seed,
            )
            + 1e-6
        )
        mix_weight = self.backend.random.beta(
            (batch_size,), alpha, alpha, seed=seed, dtype=self.compute_dtype
        )
        return mix_weight

    def transform_images(self, images, transformation=None, training=True):
        """Paste the masked patch of each sample's partner image onto it.

        The images are always cast to the compute dtype. They are otherwise
        returned unchanged when not training or when no transformation was
        sampled.
        """
        images = self.backend.cast(images, self.compute_dtype)
        if training and transformation is not None:
            partner_images = self.backend.numpy.take(
                images, transformation["permutation_order"], axis=0
            )
            # Inside the mask take the partner's pixels, elsewhere keep ours.
            images = self.backend.numpy.where(
                transformation["batch_masks"], partner_images, images
            )
        return images

    def transform_labels(self, labels, transformation, training=True):
        """Blend each label with its partner's label using the mix weight.

        Labels are returned unchanged when not training or when no
        transformation was sampled.
        """
        if training and transformation is not None:
            permutation_order = transformation["permutation_order"]
            mix_weight = transformation["mix_weight"]

            cutout_labels = self.backend.numpy.take(
                labels, permutation_order, axis=0
            )
            # Collapse the weight to a column so it scales each label row.
            mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])
            labels = mix_weight * labels + (1.0 - mix_weight) * cutout_labels

        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        """Bounding boxes are not supported; always raises."""
        raise NotImplementedError(
            "`CutMix` does not support transforming bounding boxes."
        )

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Apply the same patch mixing to segmentation masks as to images."""
        return self.transform_images(
            segmentation_masks, transformation, training
        )

    def compute_output_shape(self, input_shape):
        """The layer preserves the shape of its input."""
        return input_shape

    def get_config(self):
        """Return the base config extended with `factor` and `seed`."""
        config = {
            "factor": self.factor,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
85 changes: 85 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class CutMixTest(testing.TestCase):
    """Tests for the `CutMix` image preprocessing layer."""

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.CutMix,
            init_kwargs={
                "factor": 1.0,
                "seed": 1,
            },
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
            # StatelessRandomGammaV3 is not supported on XLA_GPU_JIT
            run_training_check=not testing.tensorflow_uses_gpu(),
        )

    def test_cut_mix_inference(self):
        # Outside of training the layer must leave the pixel values alone.
        seed = 3481
        layer = layers.CutMix()

        np.random.seed(seed)
        inputs = np.random.randint(0, 255, size=(224, 224, 3))
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output)

    def test_cut_mix_basic(self):
        data_format = backend.config.image_data_format()

        # Fixtures are written in channels_last layout: an all-ones image
        # and an all-zeros image, each with a single masked pixel.
        inputs = np.asarray([np.ones((2, 2, 1)), np.zeros((2, 2, 1))])
        batch_masks = np.asarray(
            [
                [[[False], [True]], [[False], [False]]],
                [[[False], [False]], [[True], [False]]],
            ]
        )
        # With the two images swapped as partners, every masked pixel must
        # show the other image's value.
        expected_output = np.asarray(
            [
                [[[1.0], [0.0]], [[1.0], [1.0]]],
                [[[0.0], [0.0]], [[1.0], [0.0]]],
            ]
        )

        if data_format == "channels_first":
            # Move the channel axis so the mask lines up with the images
            # instead of relying on broadcasting.
            to_channels_first = (0, 3, 1, 2)
            inputs = np.transpose(inputs, to_channels_first)
            batch_masks = np.transpose(batch_masks, to_channels_first)
            expected_output = np.transpose(expected_output, to_channels_first)

        layer = layers.CutMix(data_format=data_format)

        transformation = {
            "batch_masks": batch_masks,
            "mix_weight": np.asarray([[[[0.7826548]]], [[[0.8133545]]]]),
            # A non-identity pairing, so the mask has a visible effect.
            "permutation_order": np.asarray([1, 0]),
        }

        output = layer.transform_images(inputs, transformation)

        self.assertAllClose(expected_output, output)

    def test_tf_data_compatibility(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3))
        else:
            input_data = np.random.random((2, 3, 8, 8))
        layer = layers.CutMix(data_format=data_format)

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        # Materializing one batch is enough to prove the layer can be traced
        # inside a tf.data pipeline.
        for output in ds.take(1):
            output.numpy()
Loading