Added opt-in JAX threaded epoch iterator#22942
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a threaded prefetch iterator for JAX-based training, enabling asynchronous data loading to improve performance. It replaces direct JAXEpochIterator instantiation with a build_jax_epoch_iterator factory and adds necessary resource management through close() methods. The review feedback highlights improvements for PEP 8 compliance, such as moving imports to the top level, cleaning up redundant instance attributes, addressing a potential race condition in the close() method, and improving the clarity of error messages.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #22942 +/- ##
==========================================
- Coverage 84.63% 83.97% -0.67%
==========================================
Files 464 464
Lines 68245 68346 +101
Branches 11186 11202 +16
==========================================
- Hits 57759 57391 -368
- Misses 7558 8046 +488
+ Partials 2928 2909 -19
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
I added tests to improve the coverage. The issue with the unit tests was probably due to a kubernetes timeout and solve simply rerunning the checks. Let me know if there anything to change |
hertschuh
left a comment
There was a problem hiding this comment.
Thanks for your interest in contributing to Keras.
First, I have a couple of questions about the result and methodology:
- I have a hard time believing that the
0.51%improvement, to take one example, is statistically significant. When I try to benchmark anything, I get much bigger variations. - My understanding of the benchmarking code is that the dataset is just a single string that is split. I don't think this is a realistic scenario to optimize for. Realistic datasets are large. Also, in many cases the dataset are numerical, not strings.
Hi @hertschuh,
As soon as I have some time I can re-run the profiling to show the difference in the traces |
Description
I have added an opt-in alternative to JAXEpochIterator, using a separate thread.
To enable it the env variable
KERAS_JAX_EPOCH_ITERATOR=threadedshould be set.I have run some benchmark with some interesting results starting from the offical benchmarks
(the benchmark code can be found here: https://github.com/AmedeoBiolatti/keras-benchmarks/tree/jax_epoch_iterator_benchmark)
(wandb runs are also available if needed)
bert_fit, batch_size=8bert_fit, batch_size=16bert_fit, batch_size=32bert_fit, batch_size=64stable_diffusion_fitgemma_fitmistral_fitbert_predictsam_predictbert_fit, batch_size=32bert_fit, batch_size=256stable_diffusion_fitstable_diffusion_fitbert_fit, batch_size=32* Given the lack of available 2xA100 machines on lambda.ai in these days I tried to experiment on different configurations/providers, but I would take those result with caution. I tried 2xH100 on lambda.ai and 2xA100 on vast.ai, but the batch sizes were probably not adapt to the machines, or the machines were running on very different conditions (eg. 250W limit for A100 on vast.ai)
** Statistically significant results
I will be happy to help to run any other benchmark and/or test necessary
Note: some of the code have been generated with coding agents, but all have been carefully reviewed and tested by me in person.
Contributor Agreement
Please review our AI-Assisted Contribution Policy and check all boxes below before submitting your PR for review:
Note: Failing to adhere to this agreement may result in your future PRs no longer being reviewed.