Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize ChannelShuffle #1433

Merged
merged 7 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions benchmarks/vectorized_channel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from keras_cv.layers import ChannelShuffle
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)


class OldChannelShuffle(BaseImageAugmentationLayer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless the behavior is changing as part of this PR, we should add a PyTest to this benchmark to verify that the old/new implementations are numerically identical (like this one)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LukeWood to confirm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see #1433 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great -- thank you!

"""Shuffle channels of an input image.

Input shape:
The expected images should be [0-255] pixel ranges.
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format

Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format

Args:
groups: Number of groups to divide the input channels. Default 3.
seed: Integer. Used to create a random seed.

Call arguments:
inputs: Tensor representing images of shape
`(batch_size, height, width, channels)`, with dtype tf.float32 / tf.uint8,
or `(height, width, channels)`, with dtype tf.float32 / tf.uint8
training: A boolean argument that determines whether the call should be run
in inference mode or training mode. Default: True.

Usage:
```python
(images, labels), _ = tf.keras.datasets.cifar10.load_data()
channel_shuffle = keras_cv.layers.ChannelShuffle()
augmented_images = channel_shuffle(images)
```
"""

# Store the layer configuration.
# `groups` is the number of channel groups that get permuted; `seed` is
# forwarded to the base layer and also kept on the instance, where
# `augment_image` hands it to `tf.random.shuffle` and `get_config`
# serializes it. Remaining keyword arguments go to the base layer untouched.
def __init__(self, groups=3, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)
self.groups = groups
self.seed = seed

def augment_image(self, image, transformation=None, **kwargs):
    """Randomly permute the channel groups of a single image.

    Args:
        image: 3D image of shape `(height, width, channels)`. The channel
            count is read from the static shape (`image.shape[2]`).
        transformation: Unused; accepted for interface compatibility.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        A tensor of shape `(height, width, channels)` whose channel groups
        have been shuffled.

    Raises:
        ValueError: If the number of channels is not divisible by
            `self.groups`.
    """
    num_channels = image.shape[2]

    # Validate first so a bad configuration fails before any op is built.
    if num_channels % self.groups != 0:
        raise ValueError(
            "The number of input channels should be "
            "divisible by the number of groups. "
            f"Received: channels={num_channels}, groups={self.groups}"
        )

    shape = tf.shape(image)
    height, width = shape[0], shape[1]
    channels_per_group = num_channels // self.groups

    # Split the channel axis into (groups, channels_per_group).
    image = tf.reshape(
        image, [height, width, self.groups, channels_per_group]
    )
    # Move the group axis to the front: tf.random.shuffle permutes along
    # the first dimension only.
    image = tf.transpose(image, perm=[2, 0, 1, 3])
    image = tf.random.shuffle(image, seed=self.seed)
    # NOTE: this is not the inverse of the transpose above. The group axis
    # ends up last, i.e. (height, width, channels_per_group, groups), so
    # the final reshape interleaves the groups. Kept exactly as-is so this
    # reference implementation keeps producing the same output.
    image = tf.transpose(image, perm=[1, 2, 3, 0])
    image = tf.reshape(image, [height, width, num_channels])

    return image

def augment_bounding_boxes(self, bounding_boxes, **kwargs):
    """Return `bounding_boxes` unchanged.

    This layer only reorders channels, so the boxes are passed through
    as-is. Extra keyword arguments are accepted and ignored.
    """
    return bounding_boxes

def augment_label(self, label, transformation=None, **kwargs):
    """Return `label` unchanged.

    `transformation` and any extra keyword arguments are accepted and
    ignored.
    """
    return label

def augment_segmentation_mask(
    self, segmentation_mask, transformation, **kwargs
):
    """Return `segmentation_mask` unchanged.

    `transformation` and any extra keyword arguments are accepted and
    ignored.
    """
    return segmentation_mask

def get_config(self):
    """Return the base layer config extended with `groups` and `seed`."""
    layer_config = super().get_config()
    layer_config["groups"] = self.groups
    layer_config["seed"] = self.seed
    return layer_config

def compute_output_shape(self, input_shape):
    """Return `input_shape` unchanged as the output shape."""
    return input_shape


class ChannelShuffleTest(tf.test.TestCase):
    """Compares `ChannelShuffle` against the `OldChannelShuffle` reference."""

    def test_consistency_with_old_impl(self):
        """Identically seeded old and new layers must agree on the output."""
        seed = 2023  # magic number
        num_groups = 3
        inputs = tf.random.uniform(shape=(1, 32, 32, 3))

        new_layer = ChannelShuffle(groups=num_groups, seed=seed)
        reference_layer = OldChannelShuffle(groups=num_groups, seed=seed)

        new_result = new_layer(inputs)
        reference_result = reference_layer(inputs)

        # The layer must actually change the image, and the change must
        # match what the reference implementation produces.
        self.assertNotAllClose(inputs, new_result)
        self.assertAllClose(reference_result, new_result)


def _time_calls(fn, images, sizes, label):
    """Time one call of `fn` on the first `n` images for every `n` in `sizes`.

    Each size gets one untimed warmup call first, so one-off costs of the
    first call at that input shape are kept out of the measurement.

    Args:
        fn: Callable taking a batch of images.
        images: Indexable batch of images; sliced as `images[:n]`.
        sizes: Iterable of batch sizes to measure.
        label: Name used in the progress output.

    Returns:
        List of runtimes in seconds, one per entry of `sizes`.
    """
    print(f"Timing {label}")
    runtimes = []
    for n_images in sizes:
        batch = images[:n_images]
        # warmup
        fn(batch)

        # perf_counter is monotonic and high-resolution, unlike time.time.
        start = time.perf_counter()
        fn(batch)
        elapsed = time.perf_counter() - start
        runtimes.append(elapsed)
        print(f"Runtime for {label}, n_images={n_images}: {elapsed}")
    return runtimes


def _as_tf_function(layer, **function_kwargs):
    """Wrap `layer` in a `tf.function` built with `function_kwargs`."""

    @tf.function(**function_kwargs)
    def apply_aug(inputs):
        return layer(inputs)

    return apply_aug


def _save_plot(results, sizes, filename):
    """Plot every runtime series in `results` against `sizes` and save it."""
    plt.figure()
    for key, runtimes in results.items():
        plt.plot(sizes, runtimes, label=key)
    plt.xlabel("Number images")
    plt.ylabel("Runtime (seconds)")
    plt.legend()
    plt.savefig(filename)


if __name__ == "__main__":
    # Run benchmark
    (x_train, _), _ = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype(np.float32)

    num_images = [1000, 2000, 3000, 4000, 5000, 10000]
    results = {}
    aug_candidates = [ChannelShuffle, OldChannelShuffle]
    aug_args = {"groups": 3}

    for aug in aug_candidates:
        # A fresh layer instance is created for every execution mode.

        # Eager Mode
        c = aug.__name__
        results[c] = _time_calls(aug(**aug_args), x_train, num_images, c)

        # Graph Mode
        c = aug.__name__ + " Graph Mode"
        graph_fn = _as_tf_function(aug(**aug_args))
        results[c] = _time_calls(graph_fn, x_train, num_images, c)

        # XLA Mode
        c = aug.__name__ + " XLA Mode"
        xla_fn = _as_tf_function(aug(**aug_args), jit_compile=True)
        results[c] = _time_calls(xla_fn, x_train, num_images, c)

    _save_plot(results, num_images, "comparison.png")

    # So we can actually see more relevant margins
    del results[aug_candidates[1].__name__]
    _save_plot(results, num_images, "comparison_no_old_eager.png")

    # Run unit tests
    tf.test.main()
104 changes: 65 additions & 39 deletions keras_cv/layers/preprocessing/channel_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The KerasCV Authors
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,20 +14,18 @@

import tensorflow as tf

from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (
VectorizedBaseImageAugmentationLayer,
)


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class ChannelShuffle(BaseImageAugmentationLayer):
class ChannelShuffle(VectorizedBaseImageAugmentationLayer):
"""Shuffle channels of an input image.

Input shape:
The expected images should be [0-255] pixel ranges.
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format

Output shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format
Expand All @@ -36,17 +34,10 @@ class ChannelShuffle(BaseImageAugmentationLayer):
groups: Number of groups to divide the input channels. Default 3.
seed: Integer. Used to create a random seed.

Call arguments:
inputs: Tensor representing images of shape
`(batch_size, width, height, channels)`, with dtype tf.float32 / tf.uint8,
` or (width, height, channels)`, with dtype tf.float32 / tf.uint8
training: A boolean argument that determines whether the call should be run
in inference mode or training mode. Default: True.

Usage:
```python
(images, labels), _ = tf.keras.datasets.cifar10.load_data()
channel_shuffle = keras_cv.layers.ChannelShuffle()
channel_shuffle = ChannelShuffle(groups=3)
augmented_images = channel_shuffle(images)
```
"""
Expand All @@ -56,10 +47,40 @@ def __init__(self, groups=3, seed=None, **kwargs):
self.groups = groups
self.seed = seed

def augment_image(self, image, transformation=None, **kwargs):
shape = tf.shape(image)
height, width = shape[0], shape[1]
num_channels = image.shape[2]
def get_random_transformation_batch(self, batch_size, **kwargs):
    """Draw one random ordering of the group indices per batch element.

    Returns:
        The result of argsorting a `(batch_size, self.groups)` block of
        uniform random values along its last axis, e.g. for
        batch_size=2 and self.groups=5:

            [[0, 2, 3, 4, 1],
             [4, 1, 0, 2, 3]]
    """
    # Argsorting random values row by row yields a shuffled index row.
    random_keys = self._random_generator.random_uniform(
        (batch_size, self.groups)
    )
    return tf.argsort(random_keys, axis=-1)

def augment_ragged_image(self, image, transformation, **kwargs):
    """Shuffle a single image by routing it through `augment_images`.

    `augment_images` must receive 4D images
    `(batch_size, height, width, channel)` and 2D transformations
    `(batch_size, groups)`, so a leading batch axis of size one is added
    to both inputs and removed again from the result.
    """
    batched_image = tf.expand_dims(image, axis=0)
    batched_transformation = tf.expand_dims(transformation, axis=0)
    shuffled = self.augment_images(
        images=batched_image,
        transformations=batched_transformation,
        **kwargs,
    )
    return tf.squeeze(shuffled, axis=0)

def augment_images(self, images, transformations, **kwargs):
batch_size = tf.shape(images)[0]
height, width = images.shape[1], images.shape[2]
num_channels = images.shape[3]
indices = transformations

# append batch indexes next to shuffled indices
batch_indexs = tf.repeat(tf.range(batch_size), self.groups)
batch_indexs = tf.reshape(batch_indexs, (batch_size, self.groups))
indices = tf.stack([batch_indexs, indices], axis=-1)

if not num_channels % self.groups == 0:
raise ValueError(
Expand All @@ -69,31 +90,36 @@ def augment_image(self, image, transformation=None, **kwargs):
)

channels_per_group = num_channels // self.groups
image = tf.reshape(
image, [height, width, self.groups, channels_per_group]
)
image = tf.transpose(image, perm=[2, 0, 1, 3])
image = tf.random.shuffle(image, seed=self.seed)
image = tf.transpose(image, perm=[1, 2, 3, 0])
image = tf.reshape(image, [height, width, num_channels])

return image
images = tf.reshape(
images, [batch_size, height, width, self.groups, channels_per_group]
)
images = tf.transpose(images, perm=[0, 3, 1, 2, 4])
images = tf.gather_nd(images, indices=indices)
images = tf.transpose(images, perm=[0, 2, 3, 4, 1])
images = tf.reshape(images, [batch_size, height, width, num_channels])

def augment_bounding_boxes(self, bounding_boxes, **kwargs):
return bounding_boxes
return images

def augment_label(self, label, transformation=None, **kwargs):
return label
def augment_labels(self, labels, transformations, **kwargs):
    """Return `labels` unchanged.

    `transformations` and any extra keyword arguments are accepted and
    ignored.
    """
    return labels

def augment_segmentation_mask(
self, segmentation_mask, transformation, **kwargs
def augment_segmentation_masks(
self, segmentation_masks, transformations, **kwargs
):
return segmentation_mask
return segmentation_masks

def get_config(self):
config = super().get_config()
config.update({"groups": self.groups, "seed": self.seed})
return config
def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
    """Return `bounding_boxes` unchanged.

    `transformations` and any extra keyword arguments are accepted and
    ignored.
    """
    return bounding_boxes

def compute_output_shape(self, input_shape):
return input_shape
def get_config(self):
    """Return the base layer config extended with `groups` and `seed`.

    Keys set here take precedence over same-named keys of the base config.
    """
    # Dict unpacking merges in one step; later entries override earlier
    # ones, matching the previous items()-concatenation behavior.
    return {
        **super().get_config(),
        "groups": self.groups,
        "seed": self.seed,
    }

@classmethod
def from_config(cls, config):
    """Build an instance by passing `config`'s entries as keyword arguments."""
    return cls(**config)
Loading