Documentation Inconsistency regarding flax.nnx.optimizer.Optimizer #5060

@Lucas-Fernandes-Martins

Description

Hi everyone,

I believe I found an inconsistency between the implementation of the constructor of flax.nnx.optimizer.Optimizer and what the documentation states.

According to the docs, wrt is supposed to be an optional argument.

However, looking at the class constructor:

  @_check_wrt_arg_passed
  def __init__(
    self,
    model: M,
    tx: optax.GradientTransformation,
    *,
    wrt: filterlib.Filter,  # type: ignore
  ):

wrt is actually a required argument.

Omitting wrt when creating an instance of Optimizer yields the following error:

TypeError: Missing required argument `wrt`. As of Flax 0.11.0 the `wrt` argument is required, if you want to keep the previous use nnx.ModelAndOptimizer instead of nnx.Optimizer.
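
For context, here is a minimal sketch of how a decorator like `_check_wrt_arg_passed` could produce this kind of error even though the signature already marks `wrt` as keyword-only. This is my assumption about the mechanism, not the actual Flax source: `check_wrt_arg_passed` and `ToyOptimizer` are hypothetical stand-ins, and the message is shortened.

```python
import functools


def check_wrt_arg_passed(init):
  """Hypothetical stand-in for Flax's `_check_wrt_arg_passed` decorator."""

  @functools.wraps(init)
  def wrapper(self, *args, **kwargs):
    # Intercept the call before Python's own signature check runs,
    # so a custom, more descriptive message can be raised.
    if "wrt" not in kwargs:
      raise TypeError("Missing required argument `wrt`.")
    return init(self, *args, **kwargs)

  return wrapper


class ToyOptimizer:
  """Toy class mirroring the shape of the constructor quoted above."""

  @check_wrt_arg_passed
  def __init__(self, model, tx, *, wrt):
    self.model = model
    self.tx = tx
    self.wrt = wrt


# Passing `wrt` explicitly constructs the object normally.
opt = ToyOptimizer("model", "tx", wrt="params")

# Omitting `wrt` triggers the custom TypeError from the decorator.
try:
  ToyOptimizer("model", "tx")
except TypeError as err:
  print(err)
```

Either way, the argument is effectively required at call time, which is what conflicts with the docs describing it as optional.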

Please let me know if I'm getting something wrong here. Otherwise, I'd be happy to help update the docs accordingly.
