32 changes: 16 additions & 16 deletions docs_nnx/guides/bridge_guide.ipynb
@@ -36,7 +36,7 @@
"from flax.nnx import bridge\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax.experimental import mesh_utils\n",
"from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n",
"from typing import *"
]
},
@@ -638,7 +638,7 @@
"\n",
"In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n",
"\n",
"The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).\n",
"The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen).\n",
"\n",
"### Linen to NNX\n",
"\n",
@@ -686,15 +686,14 @@
"\n",
"\n",
"print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n",
"mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n",
" axis_names=('in', 'out'))\n",
"mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))\n",
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
"with mesh:\n",
"with jax.set_mesh(mesh):\n",
" model = create_sharded_nnx_module(x)\n",
"\n",
"print(type(model.w)) # `nnx.Param`\n",
"print(model.w.sharding) # The partition annotation attached with `w`\n",
"print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh"
"print(type(model.w)) # `nnx.Param`\n",
"print(model.w.sharding) # The partition annotation attached with `w`\n",
"print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh"
]
},
{
@@ -703,9 +702,9 @@
"metadata": {},
"source": [
" We have 8 fake JAX devices now to partition this model...\n",
" <class 'flax.nnx.variables.Param'>\n",
" ('in', 'out')\n",
" GSPMDSharding({devices=[2,4]<=[8]})"
" <class 'flax.nnx.variablelib.Param'>\n",
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)\n",
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)"
]
},
{
@@ -737,8 +736,9 @@
"source": [
"class NNXDotWithParititioning(nnx.Module):\n",
" def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n",
" init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))\n",
" self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))\n",
" init_fn = nnx.initializers.lecun_normal()\n",
" self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),\n",
" sharding_names=('in', 'out'))\n",
" def __call__(self, x: jax.Array):\n",
" return x @ self.w\n",
"\n",
@@ -751,7 +751,7 @@
" # A `NNXMeta` wrapper of the underlying `nnx.Param`\n",
" assert type(variables['params']['w']) == bridge.NNXMeta\n",
" # The annotation coming from the `nnx.Param` => (in, out)\n",
" assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n",
" assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out')\n",
"\n",
" unboxed_variables = nn.unbox(variables)\n",
" variable_pspecs = nn.get_partition_spec(variables)\n",
@@ -763,7 +763,7 @@
" nn.get_partition_spec(variables))\n",
" return sharded_vars\n",
"\n",
"with mesh:\n",
"with jax.set_mesh(mesh):\n",
" variables = create_sharded_variables(jax.random.key(0), x)\n",
"\n",
"# The underlying JAX array is sharded across the 2x4 mesh\n",
@@ -774,7 +774,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
" GSPMDSharding({devices=[2,4]<=[8]})"
" NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)"
]
},
{
32 changes: 16 additions & 16 deletions docs_nnx/guides/bridge_guide.md
@@ -36,7 +36,7 @@ from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
from typing import *
```

@@ -336,7 +336,7 @@ Flax uses a metadata wrapper box over the raw JAX array to annotate how a variab

In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.

The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).
The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen).

### Linen to NNX

@@ -367,21 +367,20 @@ def create_sharded_nnx_module(x):


print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
axis_names=('in', 'out'))
mesh = jax.make_mesh((2, 4), ('in', 'out'), axis_types=(AxisType.Auto, AxisType.Auto))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
with jax.set_mesh(mesh):
model = create_sharded_nnx_module(x)

print(type(model.w)) # `nnx.Param`
print(model.w.sharding) # The partition annotation attached with `w`
print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh
print(type(model.w)) # `nnx.Param`
print(model.w.sharding) # The partition annotation attached with `w`
print(model.w.get_value().sharding) # The underlying JAX array is sharded across the 2x4 mesh
```

We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variables.Param'>
('in', 'out')
GSPMDSharding({devices=[2,4]<=[8]})
<class 'flax.nnx.variablelib.Param'>
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)

+++
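For reference, here is a self-contained sketch (not part of this diff) of the updated mesh workflow the new lines rely on: `jax.make_mesh` with explicit `AxisType`s replaces `mesh_utils.create_device_mesh`, and `jax.set_mesh` replaces the bare `with mesh:` context. The array shape, RNG key, and the assumption of 8 available devices are illustrative only.

```python
import jax
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType

# Build a 2x4 mesh with named axes; both axes use the Auto sharding type.
mesh = jax.make_mesh((2, 4), ('in', 'out'),
                     axis_types=(AxisType.Auto, AxisType.Auto))

x = jax.random.normal(jax.random.key(42), (4, 32))
with jax.set_mesh(mesh):
  # Explicitly place the array on the mesh according to a PartitionSpec.
  x_sharded = jax.device_put(x, NamedSharding(mesh, P('in', 'out')))

print(x_sharded.sharding)  # NamedSharding(..., spec=PartitionSpec('in', 'out'), ...)
```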

@@ -396,8 +395,9 @@ Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the ra
```{code-cell} ipython3
class NNXDotWithParititioning(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
init_fn = nnx.initializers.lecun_normal()
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)),
sharding_names=('in', 'out'))
def __call__(self, x: jax.Array):
return x @ self.w

@@ -410,7 +410,7 @@ def create_sharded_variables(key, x):
# A `NNXMeta` wrapper of the underlying `nnx.Param`
assert type(variables['params']['w']) == bridge.NNXMeta
# The annotation coming from the `nnx.Param` => (in, out)
assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
assert variables['params']['w'].metadata['sharding_names'] == ('in', 'out')

unboxed_variables = nn.unbox(variables)
variable_pspecs = nn.get_partition_spec(variables)
@@ -422,14 +422,14 @@ def create_sharded_variables(key, x):
nn.get_partition_spec(variables))
return sharded_vars

with mesh:
with jax.set_mesh(mesh):
variables = create_sharded_variables(jax.random.key(0), x)

# The underlying JAX array is sharded across the 2x4 mesh
print(variables['params']['w'].sharding)
```

GSPMDSharding({devices=[2,4]<=[8]})
NamedSharding(mesh=Mesh('in': 2, 'out': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('in', 'out'), memory_kind=device)

+++
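As a companion to the `create_sharded_variables` hunk above, here is a Linen-only sketch (not part of this diff) of the `nn.unbox` / `nn.get_partition_spec` / `with_sharding_constraint` pattern, without the `bridge.ToLinen` wrapper. The module, shapes, and 8-device mesh are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.sharding import AxisType

class LinenDot(nn.Module):
  out_dim: int
  @nn.compact
  def __call__(self, x):
    # `nn.with_partitioning` boxes the param with an ('in', 'out') annotation.
    w = self.param(
        'w',
        nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')),
        (x.shape[-1], self.out_dim))
    return x @ w

@jax.jit
def create_vars(key, x):
  variables = LinenDot(8).init(key, x)
  # Strip the metadata boxes and constrain each leaf to its annotated spec.
  return jax.lax.with_sharding_constraint(
      nn.unbox(variables), nn.get_partition_spec(variables))

mesh = jax.make_mesh((2, 4), ('in', 'out'),
                     axis_types=(AxisType.Auto, AxisType.Auto))
x = jnp.ones((4, 32))
with jax.set_mesh(mesh):
  variables = create_vars(jax.random.key(0), x)

print(variables['params']['w'].sharding)  # NamedSharding over the 2x4 mesh
```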

21 changes: 12 additions & 9 deletions docs_nnx/guides/flax_gspmd.ipynb
@@ -69,7 +69,7 @@
"outputs": [],
"source": [
"# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n",
"auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))"
"auto_mesh = jax.make_mesh(\n",
" (2, 4),\n",
" ('data', 'model'),\n",
" axis_types=(AxisType.Auto, AxisType.Auto),\n",
")"
]
},
{
@@ -203,7 +207,7 @@
"source": [
"### Initialize with style\n",
"\n",
"When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
"When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
"\n",
"Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out."
]
@@ -216,10 +220,9 @@
"source": [
"@jax.jit\n",
"def init_sharded_linear(key):\n",
" init_fn = nnx.nn.linear.default_kernel_init\n",
" # Shard your parameter along `model` dimension, as in model/tensor parallelism\n",
" return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n",
" kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n",
" kernel_metadata={'sharding_names': (None, 'model')})\n",
"\n",
"with jax.set_mesh(auto_mesh):\n",
" key= rngs()\n",
@@ -328,12 +331,12 @@
" init_fn = nnx.initializers.lecun_normal()\n",
" self.dot1 = nnx.Linear(\n",
" depth, depth,\n",
" kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n",
" use_bias=False, # or use `bias_init` to give it annotation too\n",
" kernel_metadata={'sharding_names': (None, 'model')},\n",
" use_bias=False, # or use `bias_metadata` to give it annotation too\n",
" rngs=rngs)\n",
" self.w2 = nnx.Param(\n",
" init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",
" sharding=('model', None),\n",
" sharding=('model', None), # same as sharding_names=('model', None)\n",
" )\n",
"\n",
" def __call__(self, x: jax.Array):\n",
@@ -512,8 +515,8 @@
" init_fn = nnx.initializers.lecun_normal()\n",
" self.dot1 = nnx.Linear(\n",
" depth, depth,\n",
" kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n",
" use_bias=False, # or use `bias_init` to give it annotation too\n",
" kernel_metadata={'sharding_names': ('embed', 'hidden')},\n",
" use_bias=False, # or use `bias_metadata` to give it annotation too\n",
" rngs=rngs)\n",
" self.w2 = nnx.Param(\n",
" init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",
21 changes: 12 additions & 9 deletions docs_nnx/guides/flax_gspmd.md
@@ -44,7 +44,7 @@ In this guide we use a standard FSDP layout and shard our devices on two axes -

```{code-cell} ipython3
# Create an auto-mode mesh of two dimensions and annotate each axis with a name.
auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
auto_mesh = jax.make_mesh(
(2, 4),
('data', 'model'),
axis_types=(AxisType.Auto, AxisType.Auto),
)
```

> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function.
@@ -89,17 +93,16 @@ with jax.set_mesh(auto_mesh):

### Initialize with style

When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.
When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.

Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out.

```{code-cell} ipython3
@jax.jit
def init_sharded_linear(key):
init_fn = nnx.nn.linear.default_kernel_init
# Shard your parameter along `model` dimension, as in model/tensor parallelism
return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),
kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))
kernel_metadata={'sharding_names': (None, 'model')})

with jax.set_mesh(auto_mesh):
key= rngs()
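Since the cell above is truncated in this view, here is a self-contained sketch (not part of this diff) of the same `kernel_metadata` spelling, extended with a `bias_metadata` annotation. The `{'sharding_names': ...}` key for the bias and the attribute access at the end are assumptions for illustration, and eager sharding (the `flax >= 0.12` default) is assumed to be on.

```python
import jax
from flax import nnx
from jax.sharding import AxisType

auto_mesh = jax.make_mesh((2, 4), ('data', 'model'),
                          axis_types=(AxisType.Auto, AxisType.Auto))

@jax.jit
def init_sharded_linear(key):
  # Annotate the kernel and bias with mesh axis names at construction time.
  return nnx.Linear(
      4, 8, rngs=nnx.Rngs(key),
      kernel_metadata={'sharding_names': (None, 'model')},
      bias_metadata={'sharding_names': ('model',)})

with jax.set_mesh(auto_mesh):
  linear = init_sharded_linear(jax.random.key(0))

print(linear.kernel.get_value().sharding)  # kernel sharded along `model`
print(linear.bias.get_value().sharding)    # bias sharded along `model`
```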
@@ -144,12 +147,12 @@ class DotReluDot(nnx.Module):
init_fn = nnx.initializers.lecun_normal()
self.dot1 = nnx.Linear(
depth, depth,
kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
use_bias=False, # or use `bias_init` to give it annotation too
kernel_metadata={'sharding_names': (None, 'model')},
use_bias=False, # or use `bias_metadata` to give it annotation too
rngs=rngs)
self.w2 = nnx.Param(
init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
sharding=('model', None),
sharding=('model', None), # same as sharding_names=('model', None)
)

def __call__(self, x: jax.Array):
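A tiny standalone sketch (not part of this diff) of the equivalence noted in the `w2` comment above: `sharding=` on `nnx.Param` behaves as `sharding_names=`. It assumes 8 devices and eager sharding (the `flax >= 0.12` default), so the value is placed on the mesh at construction.

```python
import jax
from flax import nnx
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ('data', 'model'),
                     axis_types=(AxisType.Auto, AxisType.Auto))
init_fn = nnx.initializers.lecun_normal()

with jax.set_mesh(mesh):
  # The two spellings attach the same partition annotation.
  w_a = nnx.Param(init_fn(jax.random.key(0), (8, 8)), sharding=('model', None))
  w_b = nnx.Param(init_fn(jax.random.key(0), (8, 8)), sharding_names=('model', None))

print(w_a.get_value().sharding)  # NamedSharding with spec ('model', None)
print(w_b.get_value().sharding)  # identical placement
```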
@@ -258,8 +261,8 @@ class LogicalDotReluDot(nnx.Module):
init_fn = nnx.initializers.lecun_normal()
self.dot1 = nnx.Linear(
depth, depth,
kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),
use_bias=False, # or use `bias_init` to give it annotation too
kernel_metadata={'sharding_names': ('embed', 'hidden')},
use_bias=False, # or use `bias_metadata` to give it annotation too
rngs=rngs)
self.w2 = nnx.Param(
init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation