Skip to content

Experiment: mutable RNG object #28845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1415,3 +1415,16 @@ pytype_library(
":jax",
],
)

pytype_library(
name = "experimental_mutable_rng",
srcs = [
"experimental/mutable_rng.py",
],
deps = [
":core",
":jax",
":tree_util",
":typing",
] + py_deps("numpy"),
)
211 changes: 211 additions & 0 deletions jax/experimental/mutable_rng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Experimental implicitly-updated PRNG, based on MutableArray.
"""

import dataclasses
import operator
from typing import Sequence

import jax.numpy as jnp
from jax._src import core
from jax._src import dtypes
from jax._src import random
from jax._src import tree_util
from jax._src import typing

from jax._src.typing import Array, ArrayLike, DTypeLike

import numpy as np


def _canonicalize_size(size: int | Sequence[int] | None, *args: ArrayLike) -> tuple[int, ...]:
if size is None:
return np.broadcast_shapes(*(np.shape(arg) for arg in args))
elif isinstance(size, int):
return (size,)
else:
return tuple(map(operator.index, size))


@tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class MutablePRNG:
"""Implicit PRNG backed by MutableArray.

This should be instantiated using the :func:`default_rng` function.

Attributes:
base_key: a typed JAX PRNG key object (see :func:`jax.random.key`).
counter: a scalar integer wrapped in a :class:`MutableArray`

>>> from jax.experimental.mutable_rng import default_rng
>>> rng = default_rng(42)
>>> rng
MutablePRNG(base_key=Array((), dtype=key<fry>) overlaying:
[ 0 42], counter=MutableArray(0, dtype=int32, weak_type=True))
"""
base_key: Array
counter: core.MutableArray

def __post_init__(self):
if not (isinstance(self.base_key, Array)
and dtypes.issubdtype(self.base_key.dtype, dtypes.prng_key)):
raise ValueError(f"Expected base_key to be a typed PRNG key; got {self.base_key}")
# TODO(jakevdp): how to validate a traced mutable array?
if not (isinstance(self.counter, (core.MutableArray, core.Tracer))
and self.counter.shape == ()
and dtypes.issubdtype(self.counter.dtype, jnp.integer)):
raise ValueError(f"Expected counter to be a mutable scalar integer; got {self.counter}")

def key(self) -> Array:
"""Generate a new JAX PRNGKey, updating the internal state.

Returns:
A new, independent PRNG key with the same impl/dtype as
``self.base_key``.

Examples:
>>> from jax.experimental.mutable_rng import default_rng
>>> rng = default_rng(0)
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[1797259609 2579123966]
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[ 928981903 3453687069]
"""
key = random.fold_in(self.base_key, self.counter[...])
self.counter[...] += 1
return key

def random(
self,
size: int | Sequence[int] | None = None,
dtype: DTypeLike = float,
):
"""Return random floats in the half-open interval [0.0, 1.0)."""
# TODO(jakevdp): write docstring
return random.uniform(self.key(), shape=_canonicalize_size(size), dtype=dtype)


def uniform(
self,
low: ArrayLike = 0,
high: ArrayLike = 1,
size: int | Sequence[int] | None = None,
*,
dtype: DTypeLike = float,
) -> Array:
"""Draw uniformly distributed pseudorandom values."""
# TODO(jakevdp): write docstring
return random.uniform(self.key(), _canonicalize_size(size, low, high),
minval=low, maxval=high, dtype=dtype)

def normal(
self,
loc: ArrayLike = 0,
scale: ArrayLike = 1,
size: int | Sequence[int] | None = None,
*,
dtype: DTypeLike = float,
) -> Array:
"""Draw normally-distributed pseudorandom values."""
# TODO(jakevdp): write docstring
norm = random.normal(self.key(), _canonicalize_size(size, loc, scale), dtype)
return (jnp.asarray(loc) + jnp.asarray(scale) * norm).astype(dtype)

def integers(
self,
low: ArrayLike,
high: ArrayLike | None = None,
size: int | Sequence[int] | None = None,
*,
dtype: DTypeLike = int,
) -> Array:
"""Draw pseudorandom integers."""
# TODO(jakevdp): write docstring
if high is None:
low, high = 0, low
return random.randint(self.key(), _canonicalize_size(size, low, high),
minval=low, maxval=high, dtype=dtype)

def spawn(self, n_children: int) -> list['MutablePRNG']:
"""Create new independent child generators.

Args:
n_children: non-negative integer.

Returns:
A list of length ``n_children`` containing new independent ``MutablePRNG`` instances
spawned from the original instance.

Examples:
>>> from jax.experimental.mutable_rng import default_rng
>>> rng = default_rng(123)
>>> child_rngs = rng.spawn(2)
>>> [crng.integers(0, 10, 2) for crng in child_rngs]
[Array([1, 3], dtype=int32), Array([9, 9], dtype=int32)]
"""
return [self.__class__(self.key(), core.mutable_array(0)) for _ in range(n_children)]


def default_rng(seed: typing.ArrayLike, *,
impl: random.PRNGSpecDesc | None = None) -> MutablePRNG:
"""
Implicitly updated PRNG API.

Args:
seed: a 64- or 32-bit integer used as the value of the key.
impl: optional string specifying the PRNG implementation (e.g.
``'threefry2x32'``)

Returns:
A MutablePRNG object, with methods for generating random values.

Examples:
>>> from jax.experimental.mutable_rng import default_rng
>>> rng = default_rng(42)

Repeated draws implicitly update the key:

>>> rng.uniform()
Array(0.5302608, dtype=float32)
>>> rng.uniform()
Array(0.72766423, dtype=float32)

This also works under transformations like :func:`jax.jit`:

>>> import jax
>>> jit_uniform = jax.jit(rng.uniform)
>>> jit_uniform()
Array(0.6672406, dtype=float32)
>>> jit_uniform()
Array(0.3890121, dtype=float32)

Keys can be generated directly if desired:

>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2954079971 3276725750]
>>> rng.key()
Array((), dtype=key<fry>) overlaying:
[2765691542 824333390]
"""
return MutablePRNG(
base_key=random.key(seed, impl=impl),
counter=core.mutable_array(0)
)
11 changes: 11 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,17 @@ jax_multiplatform_test(
]),
)

jax_multiplatform_test(
name = "mutable_rng_test",
srcs = ["mutable_rng_test.py"],
deps = [
"//jax:experimental_mutable_rng",
] + py_deps([
"absl/testing",
"numpy",
]),
)

jax_multiplatform_test(
name = "for_loop_test",
srcs = ["for_loop_test.py"],
Expand Down
147 changes: 147 additions & 0 deletions tests/mutable_rng_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import numpy as np

import jax
from jax.experimental.mutable_rng import default_rng, MutablePRNG
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu


config.parse_flags_with_absl()

class MutableRNGTest(jtu.JaxTestCase):
def test_mutable_rng_instantiation(self, seed=547389):
rng = default_rng(seed)
key = jax.random.key(seed)

self.assertEqual(key, rng.base_key)
self.assertEqual(rng.counter.shape, ())
self.assertEqual(0, rng.counter[...])

_ = rng.key()
self.assertEqual(key, rng.base_key)
self.assertEqual(rng.counter.shape, ())
self.assertEqual(1, rng.counter[...])

def test_mutable_rng_invalid_instantiation(self):
valid_key = jax.random.key(0)
valid_counter = core.mutable_array(0)
invalid_key = jax.numpy.array([0, 1], dtype='uint32')
invalid_counter = 0
with self.assertRaisesRegex(ValueError, "Expected base_key to be a typed PRNG key"):
MutablePRNG(invalid_key, valid_counter)
with self.assertRaisesRegex(ValueError, "Expected counter to be a mutable scalar integer"):
MutablePRNG(valid_key, invalid_counter)

def testRepeatedKeys(self, seed=578543):
prng = default_rng(seed)
self.assertNotEqual(prng.key(), prng.key())

def testRepeatedDraws(self, seed=328090):
prng = default_rng(seed)
vals1 = prng.uniform(size=10)
vals2 = prng.uniform(size=10)
self.assertTrue((vals1 != vals2).all())

def testRepeatedDrawsJIT(self, seed=328090):
prng = default_rng(seed)
@jax.jit
def get_values(prng):
return prng.uniform(size=10)
vals1 = get_values(prng)
vals2 = get_values(prng)
self.assertTrue((vals1 != vals2).all())

@jtu.sample_product(
size=[None, 2, (5, 2)],
dtype=jtu.dtypes.floating,
)
def testRandom(self, size, dtype):
rng = default_rng(578943)
vals = rng.random(size, dtype)
shape = np.broadcast_shapes(size or ())

self.assertEqual(vals.shape, shape)
self.assertEqual(vals.dtype, dtype)
self.assertTrue((vals < 1).all())
self.assertTrue((vals >= 0).all())

@jtu.sample_product(
low=[0, 1, np.array([0, 1])],
high=[2, 3, np.array([2, 3])],
size=[None, 2, (5, 2)],
dtype=jtu.dtypes.floating,
)
@jax.numpy_dtype_promotion('standard')
@jax.numpy_rank_promotion('allow')
def testUniform(self, low, high, size, dtype):
rng = default_rng(473289)
vals = rng.uniform(low, high, size, dtype=dtype)
shape = np.broadcast_shapes(np.shape(low), np.shape(high), size or ())

self.assertEqual(vals.shape, shape)
self.assertEqual(vals.dtype, dtype)
self.assertTrue((vals < high).all())
self.assertTrue((vals >= low).all())

@jtu.sample_product(
loc=[0, 1, np.array([0, 1])],
scale=[2, 3, np.array([2, 3])],
size=[None, 2, (5, 2)],
dtype=jtu.dtypes.floating,
)
@jax.numpy_dtype_promotion('standard')
@jax.numpy_rank_promotion('allow')
def testNormal(self, loc, scale, size, dtype):
rng = default_rng(473289)
vals = rng.normal(loc, scale, size, dtype=dtype)
shape = np.broadcast_shapes(np.shape(loc), np.shape(scale), size or ())

self.assertEqual(vals.shape, shape)
self.assertEqual(vals.dtype, dtype)

@jtu.sample_product(
low=[0, 1, np.array([0, 1])],
high=[10, 15, np.array([10, 15])],
size=[None, 2, (5, 2)],
dtype=jtu.dtypes.integer,
)
@jax.numpy_dtype_promotion('standard')
@jax.numpy_rank_promotion('allow')
def testIntegers(self, low, high, size, dtype):
rng = default_rng(473289)
vals = rng.integers(low, high, size, dtype=dtype)
shape = np.broadcast_shapes(np.shape(low), np.shape(high), size or ())

self.assertEqual(vals.shape, shape)
self.assertEqual(vals.dtype, dtype)
self.assertTrue((vals < high).all())
self.assertTrue((vals >= low).all())

def testSpawn(self):
rng = default_rng(758943)
rngs = rng.spawn(4)

for child_rng in rngs:
self.assertNotEqual(rng.base_key, child_rng.base_key)
self.assertEqual(0, child_rng.counter[...])


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Loading