Skip to content

Commit 00601f0

Browse files
committed
Change introduction to randomness
1 parent e300469 commit 00601f0

File tree

2 files changed

+50
-28
lines changed

2 files changed

+50
-28
lines changed

docs_nnx/guides/randomness.ipynb

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,7 @@
66
"source": [
77
"# Randomness\n",
88
"\n",
9-
"Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX _defines the random state as an object state_. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in `nnx.Variable`s; and 3) held by the models themselves.\n",
10-
"\n",
11-
"The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n",
12-
"\n",
13-
"- It is **explicit**.\n",
14-
"- It is **order-based**.\n",
15-
"- It uses **dynamic counters**.\n",
16-
"\n",
17-
"This is a bit different from [Flax Linen's PRNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html), which is `(path + order)`-based, and uses static counters.\n",
18-
"\n",
19-
"> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n",
20-
"\n",
21-
"Let’ start with some necessary imports:"
9+
"Flax NNX uses the stateful `nnx.Rngs` class to simplify Jax's handling of random states. For example, the code below uses a `nnx.Rngs` object to define a simple linear model with dropout:"
2210
]
2311
},
2412
{
@@ -28,14 +16,35 @@
2816
"outputs": [],
2917
"source": [
3018
"from flax import nnx\n",
31-
"import jax\n",
32-
"from jax import random, numpy as jnp"
19+
"\n",
20+
"class Model(nnx.Module):\n",
21+
" def __init__(self, *, rngs: nnx.Rngs):\n",
22+
" self.linear = nnx.Linear(20, 10, rngs=rngs)\n",
23+
" self.drop = nnx.Dropout(0.1)\n",
24+
"\n",
25+
" def __call__(self, x, *, rngs):\n",
26+
" return nnx.relu(self.drop(self.linear(x), rngs=rngs))\n",
27+
"\n",
28+
"rngs = nnx.Rngs(0)\n",
29+
"model = Model(rngs=rngs) # pass rngs to initialize parameters\n",
30+
"x = rngs.normal((32, 20)) # convenient jax.random methods\n",
31+
"y = model(x, rngs=rngs) # pass rngs for dropout masks"
3332
]
3433
},
3534
{
3635
"cell_type": "markdown",
3736
"metadata": {},
3837
"source": [
38+
"We always pass `nnx.Rngs` objects to models at initialization (to initialize parameters). For models with nondeterministic outputs like the one above, we also pass `nnx.Rngs` objects to the model's `__call__` method.\n",
39+
"\n",
40+
"The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n",
41+
"\n",
42+
"- It is **explicit**.\n",
43+
"- It is **order-based**.\n",
44+
"- It uses **dynamic counters**.\n",
45+
"\n",
46+
"> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n",
47+
"\n",
3948
"## `Rngs`, `RngStream`, and `RngState`\n",
4049
"\n",
4150
"In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n",
@@ -199,6 +208,7 @@
199208
}
200209
],
201210
"source": [
211+
"import jax.numpy as jnp\n",
202212
"dropout(jnp.ones(4), rngs=rngs)"
203213
]
204214
},
@@ -295,6 +305,7 @@
295305
"metadata": {},
296306
"outputs": [],
297307
"source": [
308+
"import jax\n",
298309
"rngs = nnx.Rngs(0, params=1)\n",
299310
"\n",
300311
"# using jax.random\n",
@@ -539,7 +550,7 @@
539550
"\n",
540551
"In Flax NNX, there are two ways to approach this:\n",
541552
"\n",
542-
"1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously. \n",
553+
"1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.\n",
543554
"2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.\n",
544555
"\n",
545556
"`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.\n",

docs_nnx/guides/randomness.md

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,35 @@ jupytext:
1010

1111
# Randomness
1212

13-
Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX _defines the random state as an object state_. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in `nnx.Variable`s; and 3) held by the models themselves.
13+
Flax NNX uses the stateful `nnx.Rngs` class to simplify Jax's handling of random states. For example, the code below uses a `nnx.Rngs` object to define a simple linear model with dropout:
14+
15+
```{code-cell} ipython3
16+
from flax import nnx
17+
18+
class Model(nnx.Module):
19+
def __init__(self, *, rngs: nnx.Rngs):
20+
self.linear = nnx.Linear(20, 10, rngs=rngs)
21+
self.drop = nnx.Dropout(0.1)
22+
23+
def __call__(self, x, *, rngs):
24+
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
25+
26+
rngs = nnx.Rngs(0)
27+
model = Model(rngs=rngs) # pass rngs to initialize parameters
28+
x = rngs.normal((32, 20)) # convenient jax.random methods
29+
y = model(x, rngs=rngs) # pass rngs for dropout masks
30+
```
31+
32+
We always pass `nnx.Rngs` objects to models at initialization (to initialize parameters). For models with nondeterministic outputs like the one above, we also pass `nnx.Rngs` objects to the model's `__call__` method.
1433

1534
The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:
1635

1736
- It is **explicit**.
1837
- It is **order-based**.
1938
- It uses **dynamic counters**.
2039

21-
This is a bit different from [Flax Linen's PRNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html), which is `(path + order)`-based, and uses static counters.
22-
2340
> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).
2441
25-
Let’ start with some necessary imports:
26-
27-
```{code-cell} ipython3
28-
from flax import nnx
29-
import jax
30-
from jax import random, numpy as jnp
31-
```
32-
3342
## `Rngs`, `RngStream`, and `RngState`
3443

3544
In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).
@@ -85,6 +94,7 @@ dropout = nnx.Dropout(0.5)
8594
```
8695

8796
```{code-cell} ipython3
97+
import jax.numpy as jnp
8898
dropout(jnp.ones(4), rngs=rngs)
8999
```
90100

@@ -127,6 +137,7 @@ As shown above, a PRNG key from the `default` stream can also be generated by ca
127137
Since a very common pattern is to sample a key and immediately pass it to a function from `jax.random`, both `Rngs` and `RngStream` expose the same functions as methods with the same signature except they don't require a key:
128138

129139
```{code-cell} ipython3
140+
import jax
130141
rngs = nnx.Rngs(0, params=1)
131142
132143
# using jax.random
@@ -224,7 +235,7 @@ In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` e
224235

225236
In Flax NNX, there are two ways to approach this:
226237

227-
1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
238+
1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.
228239
2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.
229240

230241
`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.

0 commit comments

Comments
 (0)