Skip to content

Conversation

@cantonios
Copy link
Contributor

If a variable has a particular layout (e.g. is sharded across multiple devices/hosts), then any corresponding auxiliary variables, like optimizer gradient accumulators, need to have the same layout. This requires use to set the variable layout prior to initialization so that it is initialized correctly and efficiently across devices.

Added optional kwargs to the base Variable class so they can handle (and ignore) any backend-specific options. Modified JaxVariable to allow setting the layout parameter on construction. Modified add_variable_from_reference to copy the layout from the reference variable.

If a variable has a particular layout (e.g. is sharded across
multiple devices/hosts), then any corresponding auxiliary variables,
like optimizer gradient accumulators, need to have the same layout.
This requires use to set the variable layout prior to initialization
so that it is initialized correctly and efficiently across devices.

Added optional `kwargs` to the base `Variable` class so they can
handle (and ignore) any backend-specific options.  Modified
`JaxVariable` to allow setting the layout parameter on construction.
Modified `add_variable_from_reference` to copy the layout from
the reference variable.
@codecov-commenter
Copy link

codecov-commenter commented May 7, 2025

Codecov Report

Attention: Patch coverage is 81.81818% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.54%. Comparing base (f98b91f) to head (142ace2).

Files with missing lines Patch % Lines
keras/src/backend/jax/core.py 80.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21264      +/-   ##
==========================================
- Coverage   82.55%   82.54%   -0.01%     
==========================================
  Files         564      564              
  Lines       54629    54636       +7     
  Branches     8493     8494       +1     
==========================================
+ Hits        45097    45102       +5     
- Misses       7442     7443       +1     
- Partials     2090     2091       +1     
Flag Coverage Δ
keras 82.36% <81.81%> (-0.01%) ⬇️
keras-jax 63.63% <81.81%> (+<0.01%) ⬆️
keras-numpy 58.75% <18.18%> (-0.01%) ⬇️
keras-openvino 32.96% <18.18%> (-0.01%) ⬇️
keras-tensorflow 64.04% <18.18%> (-0.01%) ⬇️
keras-torch 63.69% <18.18%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels May 8, 2025
@fchollet fchollet merged commit cbb3682 into keras-team:master May 8, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kokoro:force-run ready to pull Ready to be merged into the codebase size:S

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants