Skip to content

WIP: fix RandAugment behavior in tf graph mode #21185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 114 additions & 120 deletions keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random

import keras.src.layers as layers
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
Expand Down Expand Up @@ -37,20 +36,6 @@ class RandAugment(BaseImagePreprocessingLayer):
_USE_BASE_FACTOR = False
_FACTOR_BOUNDS = (0, 1)

_AUGMENT_LAYERS = [
"random_shear",
"random_translation",
"random_rotation",
"random_brightness",
"random_color_degeneration",
"random_contrast",
"random_sharpness",
"random_posterization",
"solarization",
"auto_contrast",
"equalization",
]

def __init__(
self,
value_range=(0, 255),
Expand All @@ -70,92 +55,84 @@ def __init__(
self.seed = seed
self.generator = SeedGenerator(seed)

self.random_shear = layers.RandomShear(
x_factor=self.factor,
y_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_translation = layers.RandomTranslation(
height_factor=self.factor,
width_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_rotation = layers.RandomRotation(
factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_brightness = layers.RandomBrightness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_color_degeneration = layers.RandomColorDegeneration(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_contrast = layers.RandomContrast(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_sharpness = layers.RandomSharpness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.solarization = layers.Solarization(
addition_factor=self.factor,
threshold_factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_posterization = layers.RandomPosterization(
factor=max(1, int(8 * self.factor[1])),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.auto_contrast = layers.AutoContrast(
value_range=self.value_range, data_format=data_format, **kwargs
)

self.equalization = layers.Equalization(
value_range=self.value_range, data_format=data_format, **kwargs
)
self.augmentations = [
layers.RandomShear(
x_factor=self.factor,
y_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomTranslation(
height_factor=self.factor,
width_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomRotation(
factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomBrightness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomColorDegeneration(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomContrast(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomSharpness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.Solarization(
addition_factor=self.factor,
threshold_factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.RandomPosterization(
factor=max(1, int(8 * self.factor[1])),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
),
layers.AutoContrast(
value_range=self.value_range, data_format=data_format, **kwargs
),
layers.Equalization(
value_range=self.value_range, data_format=data_format, **kwargs
)
]
self.num_layers = len(self.augmentations)

def build(self, input_shape):
    """Build every augmentation sub-layer with the same input shape.

    Args:
        input_shape: Shape forwarded unchanged to each sub-layer's
            `build`.
    """
    # The sub-layers live in the `self.augmentations` list; the old
    # name-based lookup through `_AUGMENT_LAYERS` no longer applies.
    for augmentation_layer in self.augmentations:
        augmentation_layer.build(input_shape)

def get_random_transformation(self, data, training=True, seed=None):
Expand All @@ -165,34 +142,55 @@ def get_random_transformation(self, data, training=True, seed=None):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

for layer_name in self._AUGMENT_LAYERS:
augmentation_layer = getattr(self, layer_name)
for augmentation_layer in self.augmentations:
augmentation_layer.backend.set_backend("tensorflow")

transformation = {}
random.shuffle(self._AUGMENT_LAYERS)
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
augmentation_layer = getattr(self, layer_name)
transformation[layer_name] = (
transformation = []
idx = self.backend.random.shuffle(
self.backend.numpy.arange(self.num_layers, dtype="int32"),
seed=self._get_seed_generator(self.backend._backend),
)

for i in range(self.num_layers):
augmentation_layer = self.augmentations[i]
transformation.append(
augmentation_layer.get_random_transformation(
data,
training=training,
seed=self._get_seed_generator(self.backend._backend),
)
)

return transformation

return idx, transformation

def _apply_augs(self, transformation, func_name, inputs):
aug_index, transforms = transformation


def get_fn(aug, xform):
def func(x):
if isinstance(x, dict):
z = tree.map_structure(self.backend.numpy.copy, x)
return getattr(aug, func_name)(z, xform)
return getattr(aug, func_name)(x, xform)
return func

def body(i, loop_var):
idx = aug_index[i]
return self.backend.core.switch(
idx,
[get_fn(aug, xform) for aug, xform in zip(self.augmentations, transforms)],
loop_var,
)

return self.backend.core.fori_loop(0, self.num_ops, body, inputs)

def transform_images(self, images, transformation, training=True):
    """Cast `images` to the compute dtype and, when training, augment them.

    Args:
        images: Images to transform.
        transformation: Value passed through to `self._apply_augs`.
        training: When false, no augmentation is applied.

    Returns:
        `images` cast to `self.compute_dtype`, augmented when `training`
        is true.
    """
    if training:
        # Cast first so the augmentations receive the compute dtype.
        images = self.backend.cast(images, self.compute_dtype)
        # `transformation` is no longer a dict keyed by layer name, so the
        # old `.items()` loop is replaced by the index-driven helper.
        images = self._apply_augs(transformation, "transform_images", images)

    images = self.backend.cast(images, self.compute_dtype)
    return images

Expand All @@ -206,11 +204,7 @@ def transform_bounding_boxes(
training=True,
):
if training:
for layer_name, transformation_value in transformation.items():
augmentation_layer = getattr(self, layer_name)
bounding_boxes = augmentation_layer.transform_bounding_boxes(
bounding_boxes, transformation_value, training=training
)
bounding_boxes = self._apply_augs(transformation, "transform_bounding_boxes", bounding_boxes)
return bounding_boxes

def transform_segmentation_masks(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,64 @@ def test_rand_augment_tf_data_bounding_boxes(self):
bounding_box_format="xyxy",
)
ds.map(layer)

def test_rand_augment_tf_graph_mode(self):
    """The first two mapped batches must not get identical shuffled indices."""
    data_format = backend.config.image_data_format()
    shape = (4, 8, 8, 3) if data_format == "channels_last" else (4, 3, 8, 8)
    input_data = np.random.random(shape)
    layer = layers.RandAugment(data_format=data_format, seed=42, num_ops=1)

    # using tf.data.Dataset.map applies the function in graph mode
    # lambda gets shuffled transform index
    batched = tf_data.Dataset.from_tensor_slices(input_data).batch(2)
    ds = batched.map(lambda x: layer.get_random_transformation(x)[0])

    results = [output.numpy() for output in ds]
    self.assertFalse(np.all(results[0] == results[1]))

def test_rand_augment_tf_graph_mode_2(self):
    """Shuffled indices seen across batches must name more than one layer."""
    data_format = backend.config.image_data_format()
    if data_format == "channels_last":
        input_data = np.random.random((8, 8, 8, 3))
    else:
        input_data = np.random.random((8, 3, 8, 8))
    layer = layers.RandAugment(data_format=data_format, seed=42, num_ops=1)

    # using tf.data.Dataset.map applies the function in graph mode
    # lambda gets shuffled transform index
    ds = (
        tf_data.Dataset.from_tensor_slices(input_data)
        .batch(2)
        .map(lambda x: layer.get_random_transformation(x)[0])
    )

    # Collect the names of every augmentation referenced by any index.
    results = {
        layer.augmentations[i].name for output in ds for i in output.numpy()
    }
    # assertGreater reports both operands on failure, unlike assertTrue.
    self.assertGreater(len(results), 1)

def test_rand_augment_model(self):
    """Smoke test: a model containing RandAugment trains through `fit`."""
    from keras.src.models import Sequential

    data_format = backend.config.image_data_format()
    num_samples = 32
    # Two epochs are enough to exercise the layer inside `fit`; the
    # original 40 epochs only made the test slow.
    epochs = 2
    if data_format == "channels_last":
        input_data = np.random.random((num_samples, 8, 8, 3))
    else:
        input_data = np.random.random((num_samples, 3, 8, 8))
    y_true = np.random.random((num_samples, 1))

    model = Sequential(
        [
            layers.Input(input_data.shape[1:], dtype="float32"),
            layers.RandAugment(data_format=data_format, seed=42, num_ops=2),
            layers.Flatten(),
            layers.Dense(10, activation="relu"),
            layers.Dense(1),
        ]
    )
    model.compile(loss="mse")

    # verbose=0 keeps progress bars out of the test output.
    history = model.fit(
        input_data, y_true, batch_size=2, epochs=epochs, verbose=0
    )
    # Every requested epoch must have completed and recorded a loss.
    self.assertEqual(len(history.history["loss"]), epochs)
Loading