jit_compile control in BaseImageAugmentationLayer #1541

Open
bhack opened this issue Apr 11, 2022 · 11 comments

bhack commented Apr 11, 2022

This is just a follow-up to
keras-team/keras-cv#165 (comment)

@qlzh727 What do you think about adding an extra parameter to the base class for jit_compile?
https://github.com/keras-team/keras/blob/master/keras/layers/preprocessing/image_preprocessing.py#L413-L414

So that we could optionally use something like:

```python
f = def_function.function(self._augment, jit_compile=True)
self._map_fn(f, inputs)
```
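
For illustration, a minimal sketch of what such an opt-in constructor argument could look like (hypothetical: the `_augment`/`_map_fn` stubs below only stand in for the private methods in the linked file):

```python
import tensorflow as tf

class BaseImageAugmentationLayer(tf.keras.layers.Layer):
    """Sketch only: stubs stand in for the real private methods."""

    def __init__(self, jit_compile=False, **kwargs):
        super().__init__(**kwargs)
        self._jit_compile = jit_compile

    def _augment(self, image):
        # Subclasses would override this with the real augmentation.
        return tf.image.flip_left_right(image)

    def _map_fn(self, fn, inputs):
        # The real base class maps `fn` over the batch.
        return tf.map_fn(fn, inputs)

    def call(self, inputs):
        fn = self._augment
        if self._jit_compile:
            # Opt in to XLA for the per-image augmentation.
            fn = tf.function(self._augment, jit_compile=True)
        return self._map_fn(fn, inputs)
```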
fchollet (Member) commented:

Triage notes:

  • Such an option should likely not be exposed to end users, so it doesn't need to be an argument.
  • Some layers may require jit_compile in order to be performant. In such cases we should just make the tf.function(jit_compile=True) directly part of the layer implementation (see the sketch below).
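
A minimal sketch of that second pattern (the layer and the op choice here are illustrative, not from the Keras codebase):

```python
import tensorflow as tf

class XlaOnlyAugment(tf.keras.layers.Layer):
    """Illustrative layer that is only performant under XLA, so the
    compilation is an implementation detail, not an argument."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Bake tf.function(jit_compile=True) into the layer itself.
        self._augment = tf.function(self._augment_impl, jit_compile=True)

    def _augment_impl(self, image):
        return tf.image.adjust_brightness(image, delta=0.1)

    def call(self, inputs):
        return tf.map_fn(self._augment, inputs)
```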

@fchollet fchollet self-assigned this Apr 14, 2022

bhack commented Apr 14, 2022

What was the rationale for exposing this in the model compile API instead?


bhack commented Apr 14, 2022

I suppose that since we already let the user choose whether to jit_compile in the model compile API, we don't want to automatically compile layers without any user control.

https://github.com/keras-team/keras/blob/39ad2c1cb22b231baf05a0218322328c13654bda/keras/engine/training.py#L532


bhack commented May 27, 2022

/cc @qlzh727 @LukeWood
I suppose we will see a small "explosion" of XLA jit_compile failures when we enable XLA compilation.
And they will be more fatal than the ones we have in keras-team/keras-cv#291 for tf.vectorized_map.

tf.vectorized_map has an automatic fallback (at the cost of a slowdown), but XLA has a fail-fast policy: the first op we use in a layer implementation that is not implemented in TF2XLA will completely break the jit compilation.

A sort of "fallback" means something different in XLA, and it is "light outside compilation" (GPU only): it requires a registration in TF2XLA for every op you use in your implementation (and a CPU/TPU HLO implementation is still required if you want to jit_compile on those devices).
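
For illustration, a minimal sketch of the contrast, using tf.py_function as a stand-in for an op with no TF2XLA kernel:

```python
import tensorflow as tf

def augment(image):
    # tf.py_function has no TF2XLA kernel, so XLA cannot compile it.
    out = tf.py_function(lambda t: t + 1.0, [image], tf.float32)
    out.set_shape(image.shape)
    return out

images = tf.random.uniform((4, 8, 8, 3))

# tf.vectorized_map falls back to a tf.while_loop (with a warning and a
# slowdown) but still produces a result.
_ = tf.vectorized_map(augment, images)

# tf.function(jit_compile=True) instead fails fast on the first
# unsupported op, aborting compilation of the whole function.
jitted = tf.function(augment, jit_compile=True)
try:
    jitted(images[0])
except Exception as e:  # exact error class varies across TF versions
    print("XLA compilation failed:", type(e).__name__)
```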

As I am quite new to XLA internals, /cc @cheshire in case he wants to add some advice.


bhack commented Jul 9, 2022

@fchollet @qlzh727 Can you migrate this to keras-cv now? It seems @LukeWood doesn't have enough rights in this repo to do the migration.

@sachinprasadhs sachinprasadhs transferred this issue from keras-team/keras Sep 22, 2023
@sachinprasadhs sachinprasadhs transferred this issue from keras-team/tf-keras Sep 22, 2023
@sachinprasadhs sachinprasadhs self-assigned this Feb 14, 2024

sachinprasadhs commented Feb 14, 2024

As of Keras 3, jit_compile is set to "auto" in model.compile, which means it will use XLA if the model supports it.
Can we close the issue, considering the Keras 3 implementation?
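
For reference, the user-facing control in Keras 3 (standard model.compile usage):

```python
import keras

model = keras.Sequential([keras.layers.Dense(1)])

# "auto" (the default) enables XLA on the jax/tensorflow backends if
# the model supports it; True forces it, False disables it.
model.compile(optimizer="adam", loss="mse", jit_compile="auto")
```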


bhack commented Feb 14, 2024

Is it binary for the library user, i.e. jit-compile the whole model or nothing?

sachinprasadhs (Collaborator) commented:

From the doc: https://keras.io/api/models/model_training_apis/

jit_compile: Bool or "auto". Whether to use XLA compilation when compiling a model. For jax and tensorflow backends, jit_compile="auto" enables XLA compilation if the model supports it, and disabled otherwise. For torch backend, "auto" will default to eager execution and jit_compile=True will run with torch.compile with the "inductor" backend.


bhack commented Feb 14, 2024

But here we were talking about BaseImageAugmentationLayer, not the whole model interface. In any case, I don't know all the new refactors, so do what you want with this ticket.

sachinprasadhs (Collaborator) commented:

BaseImageAugmentationLayer subclasses the Keras Layer class, and Keras Layer has an attribute for operating in jit by default: self.supports_jit = True.
Below is the code reference.
https://github.com/keras-team/keras/blob/7ce3d62af7cc6959fc5a5841cfe17043dfcb8615/keras/layers/layer.py#L275
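
For illustration, a layer that cannot be jit-compiled could opt out via that attribute; a minimal sketch, assuming jit_compile="auto" resolves by checking supports_jit across the model's layers:

```python
import keras

class NonJittableAugment(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # With jit_compile="auto", a model containing this layer should
        # fall back to non-compiled execution.
        self.supports_jit = False

    def call(self, x):
        return x
```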

Please check and close the issue if there are no further questions. Thanks!


bhack commented Feb 17, 2024

As self.supports_jit is a bool, how do we control XLA vs torch.compile (dynamo) vs other compilers?
That a layer can be compiled with one stack doesn't mean it can be compiled with another backend.
