
Vectorize ChannelShuffle #1433

Merged: 7 commits, Mar 17, 2023

Conversation

james77777778 (Contributor)

What does this PR do?

Fixes #1414

Benchmark results:

[benchmark plots: comparison, comparison_no_old_eager]

Runtime spikes at some points due to retracing and the following warning:

WARNING:tensorflow:Using a while_loop for converting RandomShuffle cause there is no registered converter for this op.

Here is the full log: log.txt

Overall, I think the vectorized implementation beats the old one.

To check consistency, I manually set batch_size=1 so that tf.random.shuffle acts the same for the new and old implementations:

import tensorflow as tf

from keras_cv.layers import ChannelShuffle
# OldChannelShuffle is the pre-vectorization implementation kept in the
# benchmark script


class ChannelShuffleTest(tf.test.TestCase):
    def test_consistency_with_old_impl(self):
        # must set batch_size=1 due to randomness from
        # images = tf.random.shuffle(images, seed=self.seed)
        image_shape = (1, 32, 32, 3)
        fixed_seed = 2023
        image = tf.random.uniform(shape=image_shape) * 255.0

        layer = ChannelShuffle(groups=3, seed=fixed_seed)
        old_layer = OldChannelShuffle(groups=3, seed=fixed_seed)

        output = layer(image)
        old_output = old_layer(image)

        self.assertAllClose(old_output, output)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA-incompatible ops are used?

Who can review?

@LukeWood @ianstenbit

@james77777778 (Contributor Author)

In this PR, I found that ChannelShuffle cannot process each image in the batch independently.
I need some time to investigate, but it could be difficult because tf.random.shuffle cannot be applied per-sample to batched inputs.

@LukeWood (Contributor)

> In this PR, I found that ChannelShuffle cannot process each image in the batch independently. I need some time to investigate, but it could be difficult because tf.random.shuffle cannot be applied per-sample to batched inputs.

Can we roll our own version instead of using tf.random.shuffle? I believe this could be done with tf.gather_nd() and tf.random.uniform().

@james77777778 (Contributor Author)

> In this PR, I found that ChannelShuffle cannot process each image in the batch independently. I need some time to investigate, but it could be difficult because tf.random.shuffle cannot be applied per-sample to batched inputs.
>
> Can we roll our own version instead of using tf.random.shuffle? I believe this could be done with tf.gather_nd() and tf.random.uniform().

Thanks for the keywords.
I tested tf.TensorArray (as an output list) with an implementation similar to the previous one, but performance dropped a lot...

I'm going to try tf.gather_nd() and tf.random.uniform().

@bhack (Contributor) commented Feb 22, 2023

I think we can do:

num_channels = x.shape.as_list()[-1]

# Generate a random permutation of the channel indices
shuffled_indices = tf.random.shuffle(tf.range(num_channels), seed=seed)

# Permute the channels of the input tensor according to the random indices
shuffled_x = tf.gather(x, shuffled_indices, axis=-1)

P.S. We could also use the stateless version: tf.random.experimental.stateless_shuffle
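
For reference, a minimal sketch of the stateless variant (the input shape and seed values here are illustrative):

import tensorflow as tf

x = tf.random.uniform((2, 8, 8, 6))  # hypothetical batched input
num_channels = x.shape[-1]

# stateless_shuffle takes an explicit shape-[2] integer seed
shuffled_indices = tf.random.experimental.stateless_shuffle(
    tf.range(num_channels), seed=[2023, 0]
)

# note: this still produces a single permutation that is applied to
# every image in the batch
shuffled_x = tf.gather(x, shuffled_indices, axis=-1)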

Commit pushed: Update ChannelShuffle
- replace tf.random.shuffle with tf.random.uniform and tf.argsort
- add test of independence
- remove numerical check
@james77777778 (Contributor Author)

@LukeWood @bhack
Thanks for your guidance.

Much of the difficulty comes from how to independently shuffle indices within a batch.

I adopted the rand+argsort trick from https://stackoverflow.com/a/55317373:

    def get_random_transformation_batch(self, batch_size, **kwargs):
        # get batched shuffled indices
        # for example: batch_size=2; self.groups=5
        # indices = [
        #     [0, 2, 3, 4, 1],
        #     [4, 1, 0, 2, 3]
        # ]
        indices_distribution = tf.random.uniform(
            (batch_size, self.groups), seed=self.seed
        )
        indices = tf.argsort(indices_distribution, axis=-1)
        return indices
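
A quick standalone check of the trick, with toy shapes (not from the PR):

import tensorflow as tf

# argsort of i.i.d. uniform noise yields an independent random
# permutation of 0..groups-1 per row
noise = tf.random.uniform((2, 5), seed=2023)
perms = tf.argsort(noise, axis=-1)
print(perms)  # two rows, each a random permutation of [0, 1, 2, 3, 4]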

The shuffled indices are then used to gather:

    def augment_images(self, images, transformations, **kwargs):
        ...
        indices = transformations

        # pair each shuffled index with its batch index for gather_nd
        batch_indices = tf.reshape(
            tf.repeat(tf.range(batch_size), self.groups),
            (batch_size, self.groups),
        )
        indices = tf.stack([batch_indices, indices], axis=-1)

        channels_per_group = num_channels // self.groups
        images = tf.reshape(
            images, [batch_size, height, width, self.groups, channels_per_group]
        )
        images = tf.transpose(images, perm=[0, 3, 1, 2, 4])
        images = tf.gather_nd(images, indices=indices)
        images = tf.transpose(images, perm=[0, 2, 3, 4, 1])
        images = tf.reshape(images, [batch_size, height, width, num_channels])
        return images
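
Putting the two pieces together, here is a self-contained sketch on a toy batch (the shapes and variable names are illustrative, not the PR's final code):

import tensorflow as tf

batch_size, height, width, groups = 2, 4, 4, 3
num_channels = 6
channels_per_group = num_channels // groups
images = tf.random.uniform((batch_size, height, width, num_channels))

# one independent permutation of the group indices per image
indices = tf.argsort(tf.random.uniform((batch_size, groups)), axis=-1)

# pair each shuffled group index with its batch index for gather_nd
batch_indices = tf.reshape(
    tf.repeat(tf.range(batch_size), groups), (batch_size, groups)
)
gather_indices = tf.stack([batch_indices, indices], axis=-1)

# split channels into groups, move groups next to the batch axis,
# gather the shuffled groups per image, then restore the layout
x = tf.reshape(images, [batch_size, height, width, groups, channels_per_group])
x = tf.transpose(x, perm=[0, 3, 1, 2, 4])   # (batch, groups, h, w, c/g)
x = tf.gather_nd(x, indices=gather_indices)
x = tf.transpose(x, perm=[0, 2, 3, 4, 1])   # (batch, h, w, c/g, groups)
shuffled = tf.reshape(x, [batch_size, height, width, num_channels])
print(shuffled.shape)  # (2, 4, 4, 6)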

All unit tests passed, but it is difficult to check consistency with the non-vectorized implementation because the randomness comes from different sources (tf.random.shuffle vs. tf.random.uniform), so I removed that check from the benchmark script.

An independence test was also added.

The performance result:
[benchmark plot: comparison]

@bhack (Contributor) commented Feb 23, 2023

> Much of the difficulty comes from how to independently shuffle indices within a batch.

If I remember correctly, it is tf.random.experimental.stateless_shuffle

@james77777778 (Contributor Author)

> Much of the difficulty comes from how to independently shuffle indices within a batch.
>
> If I remember correctly, it is tf.random.experimental.stateless_shuffle

I think tf.random.experimental.stateless_shuffle acts the same as tf.random.shuffle, so both can only shuffle a tensor along its first axis?
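
For illustration, with a toy tensor (values made up):

import tensorflow as tf

x = tf.reshape(tf.range(6), (3, 2))
# tf.random.shuffle permutes along the first dimension only:
# whole rows move together, elements within a row stay put
print(tf.random.shuffle(x, seed=42))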

In ChannelShuffle, we want to shuffle channels independently in each image, e.g.:
idx=1 => shuffled_idx = [1, 2, 0]
idx=2 => shuffled_idx = [0, 2, 1]
...

It is achievable by looping over each image with the operations you mentioned.
Pseudocode:

outputs = tf.TensorArray(images.dtype, size=batch_size)
for i in range(batch_size):
    num_channels = images[i].shape.as_list()[-1]
    shuffled_indices = tf.random.shuffle(tf.range(num_channels), seed=seed)
    shuffled_x = tf.gather(images[i], shuffled_indices, axis=-1)
    outputs = outputs.write(i, shuffled_x)  # write() returns the updated array
outputs = outputs.stack()

But it is so slow that it is even slower than the non-vectorized version in eager mode.

I cannot figure out how to vectorize this with tf.random.shuffle or tf.random.experimental.stateless_shuffle. Please enlighten me if I missed something.

@bhack (Contributor) commented Feb 23, 2023

> In ChannelShuffle, we want to shuffle channels independently in each image, e.g.:
> idx=1 => shuffled_idx = [1, 2, 0]
> idx=2 => shuffled_idx = [0, 2, 1]

Something like this?
https://colab.research.google.com/gist/bhack/65760a7896ee9f6ec450ebbe8bd32ca2/exp_random_shuffle.ipynb

@james77777778 (Contributor Author) commented Feb 23, 2023

> In ChannelShuffle, we want to shuffle channels independently in each image, e.g.:
> idx=1 => shuffled_idx = [1, 2, 0]
> idx=2 => shuffled_idx = [0, 2, 1]
>
> Something like this? https://colab.research.google.com/gist/bhack/65760a7896ee9f6ec450ebbe8bd32ca2/exp_random_shuffle.ipynb

https://colab.research.google.com/drive/1b-tZErA-jCv6SNurdprdOYc6dLFr0ZSH?usp=sharing
In this notebook I manually set the values to show that:

  • there is only one set of shuffled indices for the whole batch
  • the outputs are shuffled by the same indices, not independently across the batch

@bhack (Contributor) commented Feb 23, 2023

OK, so you meant within the batch. Sorry, I was confused about which randomization policy we finally selected for the vectorization refactoring (whether independent per-image randomization within the batch is mandatory or not).

@ianstenbit ianstenbit requested a review from LukeWood March 3, 2023 19:40
@james77777778 (Contributor Author)

Kindly pinging @LukeWood @ianstenbit

For the idea behind the vectorization, please see #1433 (comment).

The current implementation shuffles channels (by groups) independently for each image in the batch.

@ianstenbit (Contributor)

@LukeWood PTAL -- what's the decision re: independent augmentation in batches?

@LukeWood (Contributor) commented Mar 7, 2023

All augmentation should be done independently on a per-image basis.

@LukeWood (Contributor) commented Mar 7, 2023

/gcbrun

Thank you @james77777778!!!


@james77777778 (Contributor Author) commented Mar 15, 2023

Added tests:
- `test_config_with_custom_name`
- `test_output_dtypes`
- `test_config`
@ianstenbit (Contributor) left a comment

Thanks for the PR!

Review thread on the benchmark line `class OldChannelShuffle(BaseImageAugmentationLayer):`

Contributor:

Unless the behavior is changing as part of this PR, we should add a PyTest to this benchmark to verify that the old/new implementations are numerically identical (like this one)

Contributor:

@LukeWood to confirm

Contributor Author:

Please see #1433 (comment)

Contributor:

Great -- thank you!

@james77777778 (Contributor Author)

Hi @ianstenbit

Luckily, I found a magic number (seed=2023 😂) that yields the same randomness for the old and new implementations. I have added the unit test to the benchmark.

I also ran the script examples/layers/preprocessing/classification/channel_shuffle_demo.py to visualize the output of the new ChannelShuffle:

[demo output image]

@ianstenbit (Contributor)

/gcbrun

@ianstenbit (Contributor)

GCB CI tests are failing due to an unrelated change; they should be green post-merge.

@ianstenbit ianstenbit merged commit 10aac22 into keras-team:master Mar 17, 2023
@james77777778 james77777778 deleted the channel-shuffle branch March 18, 2023 05:37
ghost pushed a commit to y-vectorfield/keras-cv that referenced this pull request Nov 16, 2023
* Vectorize ChannelShuffle

* Update ChannelShuffle.
- replace tf.random.shuffle by tf.random.uniform and tf.argsort
- add test of independence
- remove numerical check

* Fix typo

* add tests.
- `test_config_with_custom_name`
- `test_output_dtypes`
- `test_config`

* Add unit test to verify numerical consistency

* Fix ChannelShuffle