Hi everyone,
I believe I found an inconsistency between the constructor of flax.nnx.optimizer.Optimizer and what the documentation states.
According to the docs, wrt is supposed to be an optional argument.
However, looking at the class constructor:
@_check_wrt_arg_passed
def __init__(
    self,
    model: M,
    tx: optax.GradientTransformation,
    *,
    wrt: filterlib.Filter,  # type: ignore
):
wrt is actually a required argument.
Omitting wrt when creating an instance of Optimizer raises the following error:
TypeError: Missing required argument `wrt`. As of Flax 0.11.0 the `wrt` argument is required, if you want to keep the previous use
nnx.ModelAndOptimizer instead of nnx.Optimizer.
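To illustrate why the signature above makes wrt mandatory, here is a minimal stdlib-only sketch. It is not Flax's actual code: check_wrt_arg_passed and ToyOptimizer are hypothetical stand-ins, and I am only assuming the real decorator works along these lines (checking for the keyword and raising a friendlier TypeError).

```python
import functools


def check_wrt_arg_passed(init):
    """Hypothetical stand-in: raise a clear error when `wrt` is not passed."""

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # `wrt` is keyword-only, so it can only ever arrive via kwargs.
        if "wrt" not in kwargs:
            raise TypeError("Missing required argument `wrt`.")
        return init(self, *args, **kwargs)

    return wrapper


class ToyOptimizer:
    """Toy class mirroring the shape of the signature quoted above."""

    @check_wrt_arg_passed
    def __init__(self, model, tx, *, wrt):
        # No default on `wrt`, so Python itself would also reject a call
        # that omits it; the decorator only improves the error message.
        self.model, self.tx, self.wrt = model, tx, wrt


def try_construct(**kwargs):
    """Return "ok" on success, or the TypeError message on failure."""
    try:
        ToyOptimizer("model", "tx", **kwargs)
        return "ok"
    except TypeError as err:
        return str(err)


result_without = try_construct()           # `wrt` omitted
result_with = try_construct(wrt="params")  # `wrt` supplied
```

With the real library, I would expect the equivalent fix to be passing the filter explicitly, e.g. nnx.Optimizer(model, tx, wrt=nnx.Param).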
Please let me know if I'm getting something wrong here. Otherwise, I'd be happy to help update the docs accordingly.