Skip to content

Added opt-in JAX threaded epoch iterator#22942

Open
AmedeoBiolatti wants to merge 7 commits into
keras-team:masterfrom
AmedeoBiolatti:feat/improved_jaxepochiterator
Open

Added opt-in JAX threaded epoch iterator#22942
AmedeoBiolatti wants to merge 7 commits into
keras-team:masterfrom
AmedeoBiolatti:feat/improved_jaxepochiterator

Conversation

@AmedeoBiolatti
Copy link
Copy Markdown
Contributor

Description

I have added an opt-in alternative to JAXEpochIterator, using a separate thread.
To enable it the env variable KERAS_JAX_EPOCH_ITERATOR=threaded should be set.

I have run some benchmark with some interesting results starting from the offical benchmarks
(the benchmark code can be found here: https://github.com/AmedeoBiolatti/keras-benchmarks/tree/jax_epoch_iterator_benchmark)
(wandb runs are also available if needed)

Benchmark Device Default (5 runs avg) Threaded (5 runs avg) Speedup
bert_fit, batch_size=8 A100-SXM4-40GB 54.23 53.35 1.62% **
bert_fit, batch_size=16 A100-SXM4-40GB 95.42 94.54 0.92% **
bert_fit, batch_size=32 A100-SXM4-40GB 188.28 186.47 0.96% **
bert_fit, batch_size=64 A100-SXM4-40GB 387.88 385.90 0.51% **
stable_diffusion_fit A100-SXM4-40GB 409.28 391.34 4.38% **
gemma_fit A100-SXM4-40GB 204.76 203.68 0.53% **
mistral_fit A100-SXM4-40GB 174.47 173.47 0.57% **
bert_predict A100-SXM4-40GB 423.81 423.55 0.06%
sam_predict A100-SXM4-40GB 468.31 431.57 7.85% **
bert_fit, batch_size=32 2xH100 80GB HBM3 * 49.69 49.67 0.04%
bert_fit, batch_size=256 2xH100 80GB HBM3 * 347.45 347.66 -0.06%
stable_diffusion_fit 2xH100 80GB HBM3 * 95.22 95.31 -0.09%
stable_diffusion_fit H100 80GB HBM3 * 27.09 26.92 0.63%
bert_fit, batch_size=32 2xA100-PCIE-40GB, 250W * 446.79 439.90 1.54%

* Given the lack of available 2xA100 machines on lambda.ai in these days I tried to experiment on different configurations/providers, but I would take those result with caution. I tried 2xH100 on lambda.ai and 2xA100 on vast.ai, but the batch sizes were probably not adapt to the machines, or the machines were running on very different conditions (eg. 250W limit for A100 on vast.ai)
** Statistically significant results

I will be happy to help to run any other benchmark and/or test necessary

Note: some of the code have been generated with coding agents, but all have been carefully reviewed and tested by me in person.

Contributor Agreement

Please review our AI-Assisted Contribution Policy and check all boxes below before submitting your PR for review:

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Note: Failing to adhere to this agreement may result in your future PRs no longer being reviewed.

@github-actions github-actions Bot added the Gemma Gemma model specific issues label May 19, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a threaded prefetch iterator for JAX-based training, enabling asynchronous data loading to improve performance. It replaces direct JAXEpochIterator instantiation with a build_jax_epoch_iterator factory and adds necessary resource management through close() methods. The review feedback highlights improvements for PEP 8 compliance, such as moving imports to the top level, cleaning up redundant instance attributes, addressing a potential race condition in the close() method, and improving the clarity of error messages.

Comment thread keras/src/backend/jax/trainer.py Outdated
Comment thread keras/src/backend/jax/trainer.py Outdated
Comment thread keras/src/backend/jax/trainer.py Outdated
Comment thread keras/src/backend/jax/trainer.py Outdated
Comment thread keras/src/backend/jax/trainer.py Outdated
Comment thread keras/src/backend/jax/trainer.py Outdated
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 83.97%. Comparing base (8f09b27) to head (4d15537).
⚠️ Report is 5 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22942      +/-   ##
==========================================
- Coverage   84.63%   83.97%   -0.67%     
==========================================
  Files         464      464              
  Lines       68245    68346     +101     
  Branches    11186    11202      +16     
==========================================
- Hits        57759    57391     -368     
- Misses       7558     8046     +488     
+ Partials     2928     2909      -19     
Flag Coverage Δ
keras 83.78% <100.00%> (-0.65%) ⬇️
keras-cpu 83.78% <100.00%> (+0.04%) ⬆️
keras-gpu ?
keras-jax 57.98% <100.00%> (-0.25%) ⬇️
keras-numpy 53.60% <18.64%> (+0.04%) ⬆️
keras-openvino 59.41% <18.64%> (+0.02%) ⬆️
keras-tensorflow 59.28% <18.64%> (-0.26%) ⬇️
keras-torch 58.48% <18.64%> (-0.33%) ⬇️
keras-tpu ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@AmedeoBiolatti
Copy link
Copy Markdown
Contributor Author

I added tests to improve the coverage. The issue with the unit tests was probably due to a kubernetes timeout and solve simply rerunning the checks.

Let me know if there anything to change

Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AmedeoBiolatti

Thanks for your interest in contributing to Keras.

First, I have a couple of questions about the result and methodology:

  • I have a hard time believing that the 0.51% improvement, to take one example, is statistically significant. When I try to benchmark anything, I get much bigger variations.
  • My understanding of the benchmarking code is that the dataset is just a single string that is split. I don't think this is a realistic scenario to optimize for. Realistic datasets are large. Also, in many cases the dataset are numerical, not strings.

@hertschuh hertschuh added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer Gemma Gemma model specific issues labels Jun 2, 2026
@AmedeoBiolatti
Copy link
Copy Markdown
Contributor Author

AmedeoBiolatti commented Jun 2, 2026

@AmedeoBiolatti

Thanks for your interest in contributing to Keras.

First, I have a couple of questions about the result and methodology:

  • I have a hard time believing that the 0.51% improvement, to take one example, is statistically significant. When I try to benchmark anything, I get much bigger variations.
  • My understanding of the benchmarking code is that the dataset is just a single string that is split. I don't think this is a realistic scenario to optimize for. Realistic datasets are large. Also, in many cases the dataset are numerical, not strings.

Hi @hertschuh,

  • I've made a couple changes to the benchmarking to reduce noise: I use the first epoch as warmup and I ignore timing of the first and last 10% of steps in the second epoch. This way the timing is pretty consistent. In the 0.51% improvement case the 5-reps worst case of the threaded iterator is faster than the 5-reps best case of the default iterator. You can find the wandb project here: https://wandb.ai/abio93/keras_jax_iterator. Running a 2-tailed t-test I get 0.01% p-value
  • I'd be happy to run other tests, these were just some of the ones in the official keras benchmark suite

As soon as I have some time I can re-run the profiling to show the difference in the traces

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants