
Revisit the symbolic shape limitation on threefry_2x32 in terms of jax.export #24144

Open
2 tasks done
jjyyxx opened this issue Oct 5, 2024 · 0 comments
Labels
enhancement New feature or request

Comments


jjyyxx commented Oct 5, 2024

Currently, the following snippet does not work as expected; it raises the error "jax.random functions have limited support for shape polymorphism. In particular, the product of the known dimensions must be even."

import jax, jax.export

@jax.jit
def f(dummy):
    key = jax.random.key(0)
    return jax.random.normal(key, dummy.shape)

dummy = jax.ShapeDtypeStruct(
    shape=jax.export.symbolic_shape("B"),
    dtype=jax.numpy.float32
)

jax.export.export(f)(dummy)
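For comparison, here is a minimal sketch of the case the error message does allow: if the product of the known dimensions is even, e.g. the symbolic shape (B, 2), then the counter size is 2*B, `count.size % 2` simplifies to the constant 0, and export should go through. This assumes the same JAX version and PRNG configuration as the failing snippet; the function name `g` is illustrative only.

```python
import jax
import jax.export
import jax.numpy as jnp

@jax.jit
def g(dummy):
    key = jax.random.key(0)
    # Known dimensions multiply to 2 (even), so the symbolic size 2*B is
    # provably even and the threefry even-size check can be decided.
    return jax.random.normal(key, dummy.shape)

dummy = jax.ShapeDtypeStruct(
    shape=jax.export.symbolic_shape("B, 2"),
    dtype=jnp.float32,
)

exp = jax.export.export(g)(dummy)

# The exported function accepts any concrete batch size, including odd ones.
out = exp.call(jnp.zeros((3, 2), jnp.float32))
print(out.shape)  # (3, 2)
```

Only the known factor needs to be even here; B itself may still be odd at call time.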

However, according to git blame, this check was originally added for jax2tf. I would like to know whether the limitation still applies to StableHLO lowering with jax.export. I understand that this involves a certain degree of shape polymorphism, but I managed to express the computation in a way that works around the limitation (to some extent):

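import jax, jax.numpy as jnp, jax.lax as lax, jax.export
import numpy as np
from jax.extend.random import threefry2x32_p

# C is an auxiliary symbol: 2*C equals 17*B rounded up to an even number,
# i.e. the padded counter size for a (B, 17) float32 output. This is
# hard-coded to the specific (B, 17) case tested below.
B, C = jax.export.symbolic_shape("B,C", constraints=["2*C == mod(17*B, 2) + 17*B"])

def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.
  """
  key1, key2 = keypair
  if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))

  # With symbolic shapes, odd_size is the expression mod(17*B, 2), not an int.
  odd_size = count.size % 2
  # Pad the flattened counts to the even length 17*B + mod(17*B, 2) == 2*C.
  x = jnp.concatenate([count.ravel(), jnp.zeros(odd_size, dtype=jnp.uint32)])
  # The equality constraint on C is what lets this reshape pass the size check.
  x = x.reshape(2, C)
  x = threefry2x32_p.bind(key1, key2, x[0], x[1])
  out = jnp.concatenate(x).reshape(C * 2)
  assert out.dtype == np.uint32
  # Drop the padding element (if any) and restore the original shape.
  return lax.reshape(out[:C * 2 - odd_size], count.shape)

# Monkeypatch the private implementation used by jax.random (experiment only).
import jax._src.prng
jax._src.prng.threefry_2x32 = threefry_2x32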

And test it with:

@jax.jit
def f(key, x, _dummy):
    return jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
x = jax.ShapeDtypeStruct((B, 17), jnp.float32)
# dummy is never read; it only binds the auxiliary symbolic dimension C.
dummy = jax.ShapeDtypeStruct((C,), jnp.float32)

e = jax.export.export(f)(key, x, dummy)

# B = 5: x.size = 85 (odd), so C = 43.
x = jnp.zeros((5, 17), jnp.float32)
dummy = jnp.zeros(((x.size + x.size%2)//2,), jnp.float32)
e.call(key, x, dummy)

# B = 6: x.size = 102 (even), so C = 51.
x = jnp.zeros((6, 17), jnp.float32)
dummy = jnp.zeros(((x.size + x.size%2)//2,), jnp.float32)
e.call(key, x, dummy)
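The pad / split / hash / merge / trim steps in threefry_2x32 can be checked on concrete sizes without JAX. The sketch below is a NumPy mimic under stated assumptions: `pad_hash_trim` and `toy_mix` are hypothetical helpers, and `toy_mix` is an arbitrary elementwise stand-in for `threefry2x32_p`, not Threefry itself. It only shows that the scheme preserves the shape of `count` for both odd and even sizes.

```python
import numpy as np

def pad_hash_trim(count: np.ndarray, hash_pair) -> np.ndarray:
    """Mimic the padding scheme of the workaround on a concrete array."""
    n = count.size
    odd = n % 2
    # Pad the flattened counts with at most one zero so the length is even.
    x = np.concatenate([count.ravel(), np.zeros(odd, dtype=np.uint32)])
    c = (n + odd) // 2  # concrete counterpart of the symbolic C
    x = x.reshape(2, c)
    y0, y1 = hash_pair(x[0], x[1])
    out = np.concatenate([y0, y1]).reshape(2 * c)
    # Drop the padding element (if any) and restore the original shape.
    return out[: 2 * c - odd].reshape(count.shape)

def identity_pair(a, b):
    return a, b

def toy_mix(a, b):
    # Hypothetical elementwise stand-in for the Threefry hash.
    return a ^ b, a + b

# With the identity "hash", the round trip returns the input unchanged.
for n in (1, 2, 17, 85, 102):
    count = np.arange(n, dtype=np.uint32)
    assert np.array_equal(pad_hash_trim(count, identity_pair), count)

mixed = pad_hash_trim(np.arange(85, dtype=np.uint32).reshape(5, 17), toy_mix)
print(mixed.shape, mixed.dtype)  # (5, 17) uint32
```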

The core difficulty is that JAX's symbolic shape system currently cannot recognize mod(B, 2) + B as an even number, so the padded count array cannot be split into two symbolic halves directly. I therefore introduced an extra dummy symbol C, tied to B by an equality constraint, to work around it. I suspect JAX may already have some internal mechanism for polymorphic shape assertions or overrides (e.g., exploiting the fact that out and count must have the same shape, and temporarily casting the shape to a simpler form for the computation in between), which could simplify the code further.
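The arithmetic fact the symbolic system cannot prove is easy to confirm on concrete values: for every integer n, n + n % 2 is even, so the constraint 2*C == mod(17*B, 2) + 17*B always has an integer solution C. A small pure-Python sketch; `helper_dim` is a hypothetical name that mirrors the `(x.size + x.size%2)//2` expression used for the dummy argument above.

```python
def helper_dim(batch: int, known: int = 17) -> int:
    """Concrete value of the auxiliary dimension C for batch size B (hypothetical helper)."""
    n = known * batch          # x.size for a (B, 17) array
    return (n + n % 2) // 2    # same as (x.size + x.size%2)//2

for b in range(1, 1000):
    n = 17 * b
    padded = n % 2 + n
    # mod(n, 2) + n is even for every integer n ...
    assert padded % 2 == 0
    # ... so 2*C == mod(17*B, 2) + 17*B is always satisfiable.
    assert 2 * helper_dim(b) == padded

print(helper_dim(5), helper_dim(6))  # 43 51
```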

Please:

  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.
jjyyxx added the enhancement (New feature or request) label Oct 5, 2024