Commit fe5dec2

Author: Flax Authors
Merge pull request #5061 from Lucas-Fernandes-Martins:main
PiperOrigin-RevId: 826201674
2 parents: 3531187 + 21f500c

2 files changed, 4 insertions(+), 4 deletions(-)


flax/nnx/README.md (2 additions, 2 deletions)
````diff
@@ -39,7 +39,7 @@ class Model(nnx.Module):
 
 
 model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
-optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
+optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)  # reference sharing
 
 @nnx.jit  # automatic state management
 def train_step(model, optimizer, x, y):
@@ -48,7 +48,7 @@ def train_step(model, optimizer, x, y):
     return ((y_pred - y) ** 2).mean()
 
   loss, grads = nnx.value_and_grad(loss_fn)(model)
-  optimizer.update(grads)  # inplace updates
+  optimizer.update(model, grads)  # inplace updates
 
   return loss
 ```
````
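For reference, the README example reads as below once both hunks are applied. This is a hedged reconstruction: the body of `Model` is an assumption (the diff only shows the hunk header `class Model(nnx.Module):`, not the class body), while the optimizer construction and `train_step` come directly from the new lines above. The change means `nnx.Optimizer` is now constructed with an explicit `wrt` filter, and `update` receives the model alongside the gradients.

```python
import optax
from flax import nnx

class Model(nnx.Module):
  # Assumed stand-in body: the diff only shows that a Model(nnx.Module)
  # class precedes the changed lines, not what it contains.
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear_out(nnx.relu(self.linear(x)))

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)  # reference sharing

@nnx.jit  # automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # inplace updates

  return loss
```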

flax/nnx/training/optimizer.py (2 additions, 2 deletions)
````diff
@@ -144,12 +144,12 @@ def __init__(
     Args:
       model: An NNX Module.
       tx: An Optax gradient transformation.
-      wrt: optional argument to filter for which :class:`Variable`'s to keep
+      wrt: filter to specify for which :class:`Variable`'s to keep
         track of in the optimizer state. These should be the :class:`Variable`'s
         that you plan on updating; i.e. this argument value should match the
         ``wrt`` argument passed to the ``nnx.grad`` call that will generate the
         gradients that will be passed into the ``grads`` argument of the
-        :func:`update` method.
+        :func:`update` method. The filter should match the filter used in nnx.grad.
     """
     if isinstance(wrt, _Missing):
       raise TypeError(
````