Revisit the symbolic shape limitation on threefry_2x32 in terms of jax.export #24144
Status: Open
Labels: enhancement (New feature or request)
Currently, the following snippet does not work as expected and raises:

    jax.random functions have limited support for shape polymorphism. In particular, the product of the known dimensions must be even.
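The evenness requirement comes from the 2x32 Threefry layout: the flattened counter array is consumed as two equal-length uint32 streams. The sketch below illustrates only that layout. `split_into_halves` is a hypothetical helper for illustration, not JAX's internal code.

```python
import jax.numpy as jnp

def split_into_halves(count):
    """Hypothetical helper: flatten a uint32 counter array and split it into two equal halves.

    Illustrates the two-stream input a 2x32 Threefry kernel expects;
    this is not JAX's implementation.
    """
    flat = jnp.ravel(count)
    if flat.size % 2 != 0:
        # An odd number of elements has no equal two-way split.
        raise ValueError(f"size {flat.size} is odd; cannot split into two halves")
    x0, x1 = jnp.split(flat, 2)
    return x0, x1

# A (2, 3) array has 6 elements, so it splits into halves of length 3.
x0, x1 = split_into_halves(jnp.arange(6, dtype=jnp.uint32).reshape(2, 3))
```

With symbolic shapes the check `flat.size % 2` is exactly the step that cannot always be decided, which is where the error above originates.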
However, according to git blame, this error was originally introduced for jax2tf. I would like to know whether the restriction still applies to StableHLO lowering via jax.export. I understand that this involves a certain degree of shape polymorphism, but I managed to express it in a way that works around the limitation (to some extent):

And test with:
The core difficulty is that JAX's symbolic shape system currently cannot recognize mod(B, 2) + B as an even number, so I introduced an extra dummy symbol C to work around it. I suspect JAX may already have some internal mechanism for polymorphic shape assertions or overrides (e.g., leveraging the fact that out and count must have the same shape, and temporarily casting the shape to a simpler form for the computation in between), which could further simplify the code.