Skip to content

Experiment: mutable RNG object #28845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented May 20, 2025

An experiment with implicitly-updated PRNG keys in JAX, based on the MutableArray experiment. The API is inspired by numpy.random.default_rng.

The mutable array code is still in development, with a few blockers to being able to rely on it for non-experimental APIs, but I thought this would be a useful exploration in the meantime!

Example:

In [1]: from jax.experimental.mutable_rng import default_rng

In [2]: rng = default_rng(123456)

# Calling rng.key() repeatedly will generate new keys each time, implicitly updating
# the internal state:
In [3]: rng.key()
Out[3]: 
Array((), dtype=key<fry>) overlaying:
[2376306259 4203193672]

In [4]: rng.key()
Out[4]: 
Array((), dtype=key<fry>) overlaying:
[ 989706967 1551494763]

# The rng object provides methods to generate values based on these implicitly-updated
# keys, inspired by the api of np.random.default_rng()
In [5]: rng.uniform()
Out[5]: Array(0.6851474, dtype=float32)

In [6]: rng.uniform()
Out[6]: Array(0.70181704, dtype=float32)

In [7]: rng.integers(0, 10, size=5)
Out[7]: Array([8, 0, 5, 7, 2], dtype=int32)

In [8]: rng.integers(0, 10, size=5)
Out[8]: Array([9, 0, 6, 7, 1], dtype=int32)

# Even when used under JIT, the implicit update propagates through due to being
# backed by the experimental MutableArray object.
In [9]: import jax

In [10]: @jax.jit
    ...: def f(rng):
    ...:   return rng.integers(0, 10, size=5)
    ...: 

In [11]: f(rng)
Out[11]: Array([7, 1, 0, 8, 1], dtype=int32)

In [12]: f(rng)
Out[12]: Array([3, 9, 2, 5, 6], dtype=int32)

@jakevdp jakevdp marked this pull request as draft May 20, 2025 13:43
@jakevdp jakevdp self-assigned this May 20, 2025
@jakevdp jakevdp changed the title Experiment: mutable RNG state Experiment: mutable RNG object May 20, 2025
@jakevdp jakevdp requested review from mattjj and dougalm May 20, 2025 14:07
@jakevdp jakevdp force-pushed the mutable-random branch 3 times, most recently from 5b0e4ce to 1728917 Compare May 20, 2025 15:35
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label May 20, 2025
@jakevdp jakevdp force-pushed the mutable-random branch 3 times, most recently from 8f3dcdc to f3a20f6 Compare May 20, 2025 19:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant