Vectorize ChannelShuffle #1433
Conversation
In this PR, I find that channel shuffle cannot shuffle channels independently for each image in the batch.

Can we roll our own version instead of using `tf.random.shuffle`?

Thanks for the keywords. I'm going to try `tf.random.uniform` and `tf.argsort`.
I think we can do:

```python
num_channels = x.shape.as_list()[-1]
# Generate a random permutation of the channel indices
shuffled_indices = tf.random.shuffle(tf.range(num_channels), seed=seed)
# Permute the channels of the input tensor according to the random indices
shuffled_x = tf.gather(x, shuffled_indices, axis=-1)
```

P.S. We could also use the stateless version.
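For reference, a minimal sketch of the stateless variant, assuming a TF release that ships `tf.random.experimental.stateless_shuffle` (the toy tensor and seed here are illustrative, not from the PR):

```python
import tensorflow as tf

x = tf.random.normal([8, 8, 16])  # toy [H, W, C] input
num_channels = x.shape.as_list()[-1]

# Stateless: the same [2]-element seed always yields the same permutation,
# which makes results reproducible across calls.
shuffled_indices = tf.random.experimental.stateless_shuffle(
    tf.range(num_channels), seed=[42, 0]
)
shuffled_x = tf.gather(x, shuffled_indices, axis=-1)
```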
- replace `tf.random.shuffle` with `tf.random.uniform` and `tf.argsort`
- add test of independence
- remove numerical check
@LukeWood @bhack Much of the difficulty comes from how to independently shuffle indices within a batch. I adopt the `tf.random.uniform` + `tf.argsort` approach:

```python
def get_random_transformation_batch(self, batch_size, **kwargs):
    # get batched shuffled indices
    # for example: batch_size=2; self.groups=5
    # indices = [
    #     [0, 2, 3, 4, 1],
    #     [4, 1, 0, 2, 3]
    # ]
    indices_distribution = tf.random.uniform(
        (batch_size, self.groups), seed=self.seed
    )
    indices = tf.argsort(indices_distribution, axis=-1)
    return indices
```

and then gather by:

```python
def augment_images(self, images, transformations, **kwargs):
    ...
    indices = transformations
    # append batch indexes next to shuffled indices
    batch_indexs = tf.reshape(
        tf.repeat(tf.range(batch_size), self.groups),
        (batch_size, self.groups),
    )
    indices = tf.stack([batch_indexs, indices], axis=-1)
    channels_per_group = num_channels // self.groups
    images = tf.reshape(
        images, [batch_size, height, width, self.groups, channels_per_group]
    )
    images = tf.transpose(images, perm=[0, 3, 1, 2, 4])
    images = tf.gather_nd(images, indices=indices)
    images = tf.transpose(images, perm=[0, 2, 3, 4, 1])
    images = tf.reshape(images, [batch_size, height, width, num_channels])
    return images
```

All unit tests passed, but it is difficult to check consistency with the non-vectorized implementation because the randomness comes from different sources (`tf.random.shuffle` vs. `tf.random.uniform` + `tf.argsort`). An independence test is also added.
If I remember correctly, …
I think in ChannelShuffle we want to shuffle channels independently in each image. It is achievable to loop over each image with the operations you mentioned:

```python
outputs = tf.TensorArray(...)
for i in range(batch_size):
    num_channels = images[i].shape.as_list()[-1]
    shuffled_indices = tf.random.shuffle(tf.range(num_channels), seed=seed)
    shuffled_x = tf.gather(images[i], shuffled_indices, axis=-1)
    outputs = outputs.write(i, shuffled_x)  # TensorArray.write returns the array
outputs.stack()
```

But it is so slow that it is even slower than the non-vectorized version in eager mode, and I cannot figure out how to vectorize this with …
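For what it's worth, a minimal loop-free sketch of the same per-image channel shuffle (ignoring groups), using `tf.argsort` over i.i.d. uniforms plus `tf.gather` with `batch_dims`; toy shapes, not the code this PR landed on:

```python
import tensorflow as tf

batch_size, height, width, num_channels = 4, 8, 8, 3
images = tf.random.normal([batch_size, height, width, num_channels])

# One independent channel permutation per image, with no Python loop.
indices = tf.argsort(
    tf.random.uniform((batch_size, num_channels)), axis=-1
)  # [B, C]

# Gather along the channel axis; batch_dims=1 pairs row b of `indices`
# with image b, i.e. result[b, h, w, c] = images[b, h, w, indices[b, c]].
shuffled = tf.gather(images, indices, axis=3, batch_dims=1)  # [B, H, W, C]
```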
Something like this?
https://colab.research.google.com/drive/1b-tZErA-jCv6SNurdprdOYc6dLFr0ZSH?usp=sharing
Ok, so you meant within the batch. Sorry, I was confused about what kind of randomization policy we finally selected for the new vectorized refactoring (mandatory within-batch randomization or not).
Kindly pinging @LukeWood @ianstenbit. For the idea behind the vectorization, please see #1433 (comment). The current implementation shuffles channels (by groups) independently for each image in the batch.
@LukeWood PTAL -- what's the decision re: independent augmentation in batches?
All augmentation should be done independently on a per-image basis.
/gcbrun Thank you @james77777778!!!
Sorry, one more thing: can we add a test like the one in https://github.com/keras-team/keras-cv/pull/1480/files#diff-1c771135874332a25b437da4a5706faea85ff719331dfe7beb255403a1a66477?
Hi @LukeWood EDITED: added the following tests:

- `test_config_with_custom_name`
- `test_output_dtypes`
- `test_config`
Thanks for the PR!
`class OldChannelShuffle(BaseImageAugmentationLayer):`
Unless the behavior is changing as part of this PR, we should add a PyTest to this benchmark to verify that the old/new implementations are numerically identical (like this one)
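For illustration, a rough sketch of such a check, assuming hypothetical `OldChannelShuffle`/`ChannelShuffle` layers that accept `groups` and `seed` (the seed value below is a placeholder, not the PR's):

```python
import numpy as np
import tensorflow as tf

def test_consistency_with_old_impl():
    # batch_size=1 with a fixed seed so both implementations draw the
    # same permutation (see the discussion of random sources above).
    images = tf.random.uniform((1, 16, 16, 4), seed=1)
    old_layer = OldChannelShuffle(groups=2, seed=2023)  # hypothetical names
    new_layer = ChannelShuffle(groups=2, seed=2023)

    old_output = old_layer(images)
    new_output = new_layer(images)

    np.testing.assert_allclose(old_output.numpy(), new_output.numpy())
```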
@LukeWood to confirm
Please see #1433 (comment)
Great -- thank you!
Hi @ianstenbit. Luckily, I got a magic number (…) that makes the numerical-consistency check pass. I also ran the script …
/gcbrun
GCB CI tests are failing due to unrelated changes; should be green post-merge.
* Vectorize ChannelShuffle
* Update ChannelShuffle.
  - replace `tf.random.shuffle` with `tf.random.uniform` and `tf.argsort`
  - add test of independence
  - remove numerical check
* Fix typo
* add tests.
  - `test_config_with_custom_name`
  - `test_output_dtypes`
  - `test_config`
* Add unit test to verify numerical consistency
* Fix ChannelShuffle
What does this PR do?
Fixes #1414
Benchmark results:
Runtime spikes at some points due to retracing and some warnings.
Here is the log.txt
Overall, I think the vectorized implementation beats the old one.
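As an aside on the retracing spikes, a minimal sketch (not from this PR) of how varying input shapes trigger `tf.function` retraces, and how an `input_signature` with relaxed shapes avoids them:

```python
import tensorflow as tf

@tf.function
def augment(images):
    # stand-in for an augmentation layer call
    return tf.reverse(images, axis=[-1])

augment(tf.zeros([8, 32, 32, 3]))
augment(tf.zeros([8, 64, 64, 3]))  # new static shape -> retrace (runtime spike)

relaxed = tf.function(
    lambda images: tf.reverse(images, axis=[-1]),
    input_signature=[tf.TensorSpec([None, None, None, 3])],
)
relaxed(tf.zeros([8, 32, 32, 3]))
relaxed(tf.zeros([8, 64, 64, 3]))  # reuses the same trace, no spike
```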
To check the consistency, I manually set `batch_size=1` to make `tf.random.shuffle` act the same for the new and old implementations.
Who can review?
@LukeWood @ianstenbit