Skip to content

Commit 740945a

Browse files
superbobryjax authors
authored andcommitted
Moved the implementation of custom_partitioning into jax/_src
This is necessary to avoid a circular dependency jax -> fused_attention_stablehlo -> experimental -> jax in #21371. PiperOrigin-RevId: 650201550
1 parent e334770 commit 740945a

File tree

3 files changed

+561
-537
lines changed

3 files changed

+561
-537
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ py_library_providing_imports_info(
209209
"_src/checkify.py",
210210
"_src/custom_batching.py",
211211
"_src/custom_derivatives.py",
212+
"_src/custom_partitioning.py",
212213
"_src/custom_transpose.py",
213214
"_src/debugging.py",
214215
"_src/dispatch.py",

0 commit comments

Comments
 (0)