|
6 | 6 | "source": [ |
7 | 7 | "# Randomness\n", |
8 | 8 | "\n", |
9 | | - "Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX _defines the random state as an object state_. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in `nnx.Variable`s; and 3) held by the models themselves.\n", |
10 | | - "\n", |
11 | | - "The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n", |
12 | | - "\n", |
13 | | - "- It is **explicit**.\n", |
14 | | - "- It is **order-based**.\n", |
15 | | - "- It uses **dynamic counters**.\n", |
16 | | - "\n", |
17 | | - "This is a bit different from [Flax Linen's PRNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html), which is `(path + order)`-based, and uses static counters.\n", |
18 | | - "\n", |
19 | | - "> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n", |
20 | | - "\n", |
21 | | - "Let’ start with some necessary imports:" |
| 9 | + "Flax NNX uses the stateful `nnx.Rngs` class to simplify JAX's handling of random states. For example, the code below uses an `nnx.Rngs` object to define a simple linear model with dropout:" |
22 | 10 | ] |
23 | 11 | }, |
24 | 12 | { |
|
28 | 16 | "outputs": [], |
29 | 17 | "source": [ |
30 | 18 | "from flax import nnx\n", |
31 | | - "import jax\n", |
32 | | - "from jax import random, numpy as jnp" |
| 19 | + "\n", |
| 20 | + "class Model(nnx.Module):\n", |
| 21 | + " def __init__(self, *, rngs: nnx.Rngs):\n", |
| 22 | + " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", |
| 23 | + " self.drop = nnx.Dropout(0.1)\n", |
| 24 | + "\n", |
| 25 | + " def __call__(self, x, *, rngs):\n", |
| 26 | + " return nnx.relu(self.drop(self.linear(x), rngs=rngs))\n", |
| 27 | + "\n", |
| 28 | + "rngs = nnx.Rngs(0)\n", |
| 29 | + "model = Model(rngs=rngs) # pass rngs to initialize parameters\n", |
| 30 | + "x = rngs.normal((32, 20)) # convenient jax.random methods\n", |
| 31 | + "y = model(x, rngs=rngs) # pass rngs for dropout masks" |
33 | 32 | ] |
34 | 33 | }, |
35 | 34 | { |
36 | 35 | "cell_type": "markdown", |
37 | 36 | "metadata": {}, |
38 | 37 | "source": [ |
| 38 | + "We always pass `nnx.Rngs` objects to models at initialization (to initialize parameters). For models with nondeterministic outputs like the one above, we also pass `nnx.Rngs` objects to the model's `__call__` method.\n", |
| 39 | + "\n", |
| 40 | + "The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n", |
| 41 | + "\n", |
| 42 | + "- It is **explicit**.\n", |
| 43 | + "- It is **order-based**.\n", |
| 44 | + "- It uses **dynamic counters**.\n", |
| 45 | + "\n", |
| 46 | + "> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n", |
| 47 | + "\n", |
39 | 48 | "## `Rngs`, `RngStream`, and `RngState`\n", |
40 | 49 | "\n", |
41 | 50 | "In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following in Flax Linen's footsteps, an `nnx.Rngs` object can create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, which gives tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", |
|
199 | 208 | } |
200 | 209 | ], |
201 | 210 | "source": [ |
| 211 | + "import jax.numpy as jnp\n", |
202 | 212 | "dropout(jnp.ones(4), rngs=rngs)" |
203 | 213 | ] |
204 | 214 | }, |
|
295 | 305 | "metadata": {}, |
296 | 306 | "outputs": [], |
297 | 307 | "source": [ |
| 308 | + "import jax\n", |
298 | 309 | "rngs = nnx.Rngs(0, params=1)\n", |
299 | 310 | "\n", |
300 | 311 | "# using jax.random\n", |
|
539 | 550 | "\n", |
540 | 551 | "In Flax NNX, there are two ways to approach this:\n", |
541 | 552 | "\n", |
542 | | - "1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously. \n", |
| 553 | + "1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.\n", |
543 | 554 | "2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.\n", |
544 | 555 | "\n", |
545 | 556 | "`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` then traverses the graph and updates the random state of the matching `nnx.RngStream`s: it sets the `key` to a possibly new value and resets the `count` to zero.\n", |
|