Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces MaybeApply layer. #435

Merged

Conversation

sebastian-sz
Copy link
Contributor

Closes keras-team/keras#422 .

Maybe Apply layer is a wrapper around native Keras layers or BaseImageAugmentationLayer that applies, specified operation to random samples in a batch.

Example usage

images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])

zero_out = tf.keras.layers.Lambda(lambda x: 0 * x)
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234)

outputs = maybe_apply(images)
print(outputs)  # Random samples in a batch have been zero'ed out.

Performance / overhead

The layer introduces an overhead for the layer it wraps. It looks like the layer wrapped in MaybeApply takes ~2x the time to execute. I'm not sure if any improvements here can be made - I tried to implement the same behaviour using tf.gather_nd + tf.scatter_nd_update but got similar or even worse results for larger batches.

Below are latency (ms) measurements for Solarization and Posterization. Both were benchmarked with XLA to avoid tf.function retracing.

Solarization + MaybeApply Overhead
Posterization + MaybeApply overhead

Known Issues

The layer throws error in XLA with auto_vectorize=True, regardless of what layer it wraps. I'm not really sure why

InvalidArgumentError: Reading input as constant from a dynamic tensor is not yet supported. Xla shape: s32[<=32]

but it works with auto_vectorize=False.

@bhack
Copy link
Contributor

bhack commented May 15, 2022

It looks like the layer wrapped in MaybeApply takes ~2x the time to execute

You could try to trace/profile the code to collect more insight:
https://www.tensorflow.org/api_docs/python/tf/profiler/experimental/Trace

@sebastian-sz
Copy link
Contributor Author

@bhack
Running with XLA for both MaybeApply and raw layer I'm getting:
Your program is NOT input-bound because only 0.0% of the total step time sampled is waiting for input.
And it says that almost entire time is spent on Host Compute Time. Still this is e.g. 12.2 ms for native layer and 25.3 ms for MaybeApply wrapped.

However, running without XLA: I'm getting a different message for MaybeApply:
Your program is POTENTIALLY input-bound because 36.5% of the total step time sampled is spent on 'All Others' time (which could be due to I/O or Python execution or both).

I am not reading any data from the disk, so I'd assume it could be Python's overhead? Not sure.

@bhack
Copy link
Contributor

bhack commented May 15, 2022

Do you have a gist with tf.profile to reproduce this?

@sebastian-sz
Copy link
Contributor Author

@bhack this is what I'm using:
benchmark_maybe_apply.zip

@bhack
Copy link
Contributor

bhack commented May 15, 2022

Known Issues
The layer throws error in XLA with auto_vectorize=True, regardless of what layer it wraps. I'm not really sure why
InvalidArgumentError: Reading input as constant from a dynamic tensor is not yet supported. Xla shape: s32[<=32]

Yes it Is:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/xla_op_kernel.cc#L138-L146

Do you have, can you post, Tensorboard screenshots for the ops section for plain vs maybe wrapper?

The ops tables are like these:
#165 (comment)

More in general with XLA is also interesting to check how well the graph ops that we have used in the implementaton are optimized-fused or not with the HLO dump (not all the possible ops permutation sequences have the same XLA coverage and optimizzation quality).
See:
#141 (comment)

@bhack
Copy link
Contributor

bhack commented May 15, 2022

@bhack this is what I'm using:
benchmark_maybe_apply.zip

You cannot do like you have done when you don't use a Keras model. If you don't want to use a model check:
https://www.tensorflow.org/tensorboard/graphs#graphs_of_tffunctions

But as tensorflow/tensorboard#1961 is still open it is better that you still use a model.

@sebastian-sz
Copy link
Contributor Author

@bhack Instead of running the layer, I should wrap it in a tf.keras.Model, compile and run .predict ?

I'm still struggling to create a HLO graph. Will post when I gather more information.

@bhack
Copy link
Contributor

bhack commented May 16, 2022

@bhack Instead of running the layer, I should wrap it in a tf.keras.Model, compile and run .predict ?

Yes as the graph is on model build:
https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L401-L405

I'm still struggling to create a HLO graph. Will post when I gather more information.

https://www.tensorflow.org/xla#inspect_compiled_programs

@LukeWood
Copy link
Contributor

Closes keras-team/keras#422 .

Maybe Apply layer is a wrapper around native Keras layers or BaseImageAugmentationLayer that applies, specified operation to random samples in a batch.

Example usage

images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])

zero_out = tf.keras.layers.Lambda(lambda x: 0 * x)
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234)

outputs = maybe_apply(images)
print(outputs)  # Random samples in a batch have been zero'ed out.

Performance / overhead

The layer introduces an overhead for the layer it wraps. It looks like the layer wrapped in MaybeApply takes ~2x the time to execute. I'm not sure if any improvements here can be made - I tried to implement the same behaviour using tf.gather_nd + tf.scatter_nd_update but got similar or even worse results for larger batches.

Below are latency (ms) measurements for Solarization and Posterization. Both were benchmarked with XLA to avoid tf.function retracing.

Solarization + MaybeApply Overhead Posterization + MaybeApply overhead

Known Issues

The layer throws error in XLA with auto_vectorize=True, regardless of what layer it wraps. I'm not really sure why

InvalidArgumentError: Reading input as constant from a dynamic tensor is not yet supported. Xla shape: s32[<=32]

but it works with auto_vectorize=False.

I feel like we need a page in our docs for performance related artifacts like these! These are great to have.

@sebastian-sz
Copy link
Contributor Author

@bhack Following your suggestion on using the tf.keras.Model class I did the benchmarks again using .predict_on_batch and model.compile(jit_compile=...).

When using tf.keras.Model the wrapper layer runs in a similar amount of time as the regular layer.
XLA mean latency (ms)
Eager mean latency (ms)

benchmark_model.zip

@bhack
Copy link
Contributor

bhack commented May 17, 2022

@sebastian-sz Can you post the same graph without XLA?

As we still don't expose an API to control the XLA compilation x layer in Keras/Keras-cv (see keras-team/keras-io#1541) using just model.compile(jit_compile=...). is very risky cause any single ops not supported by XLA will let the model compilation to fail (see #146 (comment)).

Extra: Without extending Keras-cv layers test for XLA compilation we will never know the exact list of layers with one or more usupported XLA ops.
Also the list/inventory in TF Docs hasn't been updated for years: tensorflow/tensorflow#14798 (comment)

@bhack
Copy link
Contributor

bhack commented May 17, 2022

P.s. Just to clarify I meant in graph mode as model.compile default but without XLA instread of model.compile eager that I suppose produced the eager graph in your previous post.

Edit:
Checking your ZIP and the XLA boolean I suppose that your 2nd graph is graph mode without XLA instead of eager mode also at it seems to me too fast to be eager. Right?

@sebastian-sz
Copy link
Contributor Author

@bhack

Edit:
Checking your ZIP and the XLA boolean I suppose that your 2nd graph is graph mode without XLA instead of eager mode also at it seems to me too fast to be eager. Right?

Thanks, yes, my bad. Added option for run_eagerly. Eager is slower than graph / XLA (as expected) and the difference between layers is also small:
Corrected eager mode mean latency  ms

@sebastian-sz Can you post the same graph without XLA?

I'm not sure I follow - Is the above graph what you requested?

@bhack
Copy link
Contributor

bhack commented May 17, 2022

I'm not sure I follow - Is the above graph what you requested?

No It was the previous one labeled as eager but It was in graph mode without jit compile

with self.assertRaises(ValueError):
MaybeApply(rate=invalid_rate, layer=ZeroOut())

def test_works_with_batched_input(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you pass a seed so that this test is not potentially flaky? Given, it is 1/2^32 flakiness, but still may as well seed it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added seed to rng on line 37.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! does this seed the layer too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Added seed param to layer as well.

Copy link
Contributor

@LukeWood LukeWood left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 minor comments then good to go.

@LukeWood LukeWood merged commit 282c66c into keras-team:master May 19, 2022
@LukeWood
Copy link
Contributor

Looks good to me @sebastian-sz Thanks for the contribution!

@sebastian-sz sebastian-sz deleted the feature-422/add-maybe-apply-layer branch May 19, 2022 05:51
ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
* Added MaybeApply layer.

* Changed MaybeApply to override _augment method.

* Added seed to maybe_apply_test random generator.

* Added seed to layer in batched input test.

* Fixed MaybeApply docs.
adhadse pushed a commit to adhadse/keras-cv that referenced this pull request Sep 17, 2022
* Added MaybeApply layer.

* Changed MaybeApply to override _augment method.

* Added seed to maybe_apply_test random generator.

* Added seed to layer in batched input test.

* Fixed MaybeApply docs.
freedomtan pushed a commit to freedomtan/keras-cv that referenced this pull request Jul 20, 2023
Silly bug, we were literally just adding the trainable field twice
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants