Skip to content

Commit 86bc732

Browse files
committed
[nnx] remove VariableState
1 parent 99f8e59 commit 86bc732

25 files changed

+323
-648
lines changed

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ def init_optimizer_state(variable: nnx.Variable):
8787
self.momentum: nnx.State = jax.tree.map(
8888
init_optimizer_state,
8989
self.params,
90-
is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
90+
is_leaf=lambda x: isinstance(x, nnx.Variable),
9191
)
9292
self.decay = decay
9393

9494
def update(self, grads: nnx.State):
9595
def update_fn(
96-
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
96+
params: nnx.Variable, momentum: SGDState, grad: nnx.Variable
9797
):
9898
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
9999
momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...]
@@ -105,7 +105,7 @@ def update_fn(
105105
self.params,
106106
self.momentum,
107107
grads,
108-
is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
108+
is_leaf=lambda x: isinstance(x, nnx.Variable),
109109
)
110110

111111

@@ -118,12 +118,12 @@ def create_model():
118118
state, nnx.get_named_sharding(state, mesh)
119119
)
120120

121-
def get_named_shardings(path: tuple, value: nnx.VariableState):
121+
def get_named_shardings(path: tuple, value: nnx.Variable):
122122
if path[0] == 'params':
123-
return value.replace(NamedSharding(mesh, P(*value.sharding)))
123+
return NamedSharding(mesh, P(*value.sharding))
124124
elif path[0] == 'momentum':
125125
# currently the same as above but in general it could be different
126-
return value.replace(NamedSharding(mesh, P(*value.sharding)))
126+
return NamedSharding(mesh, P(*value.sharding))
127127
else:
128128
raise ValueError(f'Unknown path: {path}')
129129

examples/nnx_toy_examples/mutable_array_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def __init__(self, params, lr: float, decay: float = 0.9):
192192
self.decay = decay
193193

194194
def make_opt_state(x):
195-
if isinstance(x, nnx.Variable | nnx.VariableState):
195+
if isinstance(x, nnx.Variable):
196196
return OptState(jnp.zeros_like(x.value), **x.get_metadata())
197197
else:
198198
return OptState(jnp.zeros_like(x))
@@ -201,7 +201,7 @@ def make_opt_state(x):
201201
jax.tree.map(
202202
make_opt_state,
203203
params,
204-
is_leaf=lambda x: isinstance(x, nnx.Variable | nnx.VariableState),
204+
is_leaf=lambda x: isinstance(x, nnx.Variable),
205205
)
206206
)
207207

flax/nnx/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@
172172
from .variablelib import Intermediate as Intermediate
173173
from .variablelib import Perturbation as Perturbation
174174
from .variablelib import Variable as Variable
175-
from .variablelib import VariableState as VariableState
176175
from .variablelib import VariableMetadata as VariableMetadata
177176
from .variablelib import with_metadata as with_metadata
178177
from .variablelib import variable_type_from_name as variable_type_from_name

flax/nnx/bridge/module.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -374,27 +374,27 @@ def _get_variables(self) -> tp.Mapping:
374374
state = graph.state(self)
375375
_variables: dict = {}
376376

377-
variable_state: variablelib.VariableState
378-
for path, variable_state in statelib.to_flat_state(state):
379-
if issubclass(variable_state.type, rnglib.RngState):
377+
variable: variablelib.Variable
378+
for path, variable in statelib.to_flat_state(state):
379+
if isinstance(variable, rnglib.RngState):
380380
# Don't return RNG states, since Linen doesn't have them.
381381
continue
382382

383383
try:
384-
collection = variablelib.variable_name_from_type(variable_state.type)
384+
collection = variablelib.variable_name_from_type(type(variable))
385385
except ValueError:
386-
collection = variable_state.type.__name__
386+
collection = type(variable).__name__
387387

388388
if collection not in _variables:
389389
_variables[collection] = {}
390390

391391
if (
392-
isinstance(variable_state, variablelib.VariableState)
393-
and not variable_state._var_metadata
392+
isinstance(variable, variablelib.Variable)
393+
and not variable._var_metadata
394394
):
395-
leaf = variable_state.value
395+
leaf = variable.value
396396
else:
397-
leaf = bridge_variables.to_linen_var(variable_state)
397+
leaf = bridge_variables.to_linen_var(variable)
398398

399399
_variables[collection][path] = leaf
400400

flax/nnx/bridge/variables.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _variable_parents_count(t: type):
4343

4444

4545
class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
46-
"""Default Flax metadata class for `nnx.VariableState`."""
46+
"""Default Flax metadata class for `nnx.Variable`."""
4747

4848
var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False)
4949
value: Any = struct.field(pytree_node=True)
@@ -65,15 +65,17 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
6565

6666
def get_partition_spec(self) -> jax.sharding.PartitionSpec:
6767
"""Returns the ``Partitionspec`` for this partitioned value."""
68-
nnx_var = self.to_nnx_variable().to_state()
69-
return spmd.get_partition_spec(nnx_var).raw_value
68+
nnx_var = self.to_nnx_variable()
69+
spec = spmd.get_partition_spec(nnx_var)
70+
assert isinstance(spec, jax.sharding.PartitionSpec)
71+
return spec
7072

7173
def to_nnx_variable(self) -> variablelib.Variable:
7274
return self.var_type(self.value, **self.metadata)
7375

7476

75-
def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
76-
"""A variables state is vanilla if its metadata is essentially blank.
77+
def is_vanilla_variable(vs: variablelib.Variable) -> bool:
78+
"""A variable is vanilla if its metadata is essentially blank.
7779
7880
Returns False only if it has non-empty hooks or any non-built-in attribute.
7981
"""
@@ -86,7 +88,7 @@ def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
8688
return True
8789

8890

89-
def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata:
91+
def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
9092
metadata = vs.get_metadata()
9193
if 'linen_meta_type' in metadata:
9294
linen_type = metadata['linen_meta_type']
@@ -145,14 +147,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
145147

146148

147149
def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
148-
"""Convert a dict of NNX variables (or variable states) to Linen-style variables."""
150+
"""Convert a dict of NNX variables to Linen-style variables."""
149151
linen_structured = {}
150152
for kp, v in traversals.flatten_mapping(nnx_attrs).items():
151153
if isinstance(v, variablelib.Variable):
152154
col_name = variablelib.variable_name_from_type(type(v))
153-
v = to_linen_var(v.to_state())
154-
elif isinstance(v, variablelib.VariableState):
155-
col_name = variablelib.variable_name_from_type(v.type)
156155
v = to_linen_var(v)
157156
elif isinstance(v, graph.GraphDef):
158157
col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule

flax/nnx/bridge/wrappers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class ToLinen(linen.Module):
252252
args: tp.Sequence = ()
253253
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
254254
skip_rng: bool = False
255-
metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = (
255+
metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = (
256256
bv.to_linen_var
257257
)
258258

@@ -310,7 +310,7 @@ def _update_variables(self, module):
310310

311311
# group state by collection
312312
for path, leaf in nnx.to_flat_state(state):
313-
type_ = leaf.type if isinstance(leaf, nnx.VariableState) else type(leaf)
313+
type_ = leaf.type if isinstance(leaf, nnx.Variable) else type(leaf)
314314
collection = variablelib.variable_name_from_type(
315315
type_, allow_register=True
316316
)
@@ -323,7 +323,7 @@ def _update_variables(self, module):
323323
if self.is_mutable_collection(collection):
324324

325325
def _to_linen_var(x):
326-
if isinstance(x, nnx.VariableState):
326+
if isinstance(x, nnx.Variable):
327327
if self.metadata_fn:
328328
return self.metadata_fn(x)
329329
else:
@@ -334,7 +334,7 @@ def _to_linen_var(x):
334334
collection_state = jax.tree.map(
335335
_to_linen_var,
336336
collection_state,
337-
is_leaf=lambda x: isinstance(x, nnx.VariableState),
337+
is_leaf=lambda x: isinstance(x, nnx.Variable),
338338
)
339339
for k, v in collection_state.items():
340340
self.put_variable(collection, k, v)
@@ -344,7 +344,7 @@ def to_linen(
344344
nnx_class: tp.Callable[..., Module],
345345
*args,
346346
metadata_fn: (
347-
tp.Callable[[variablelib.VariableState], tp.Any] | None
347+
tp.Callable[[variablelib.Variable], tp.Any] | None
348348
) = bv.to_linen_var,
349349
name: str | None = None,
350350
**kwargs,

flax/nnx/filterlib.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,7 @@ class OfType:
121121
type: type
122122

123123
def __call__(self, path: PathParts, x: tp.Any):
124-
return isinstance(x, self.type) or (
125-
hasattr(x, 'type') and issubclass(x.type, self.type)
126-
)
124+
return isinstance(x, self.type)
127125

128126
def __repr__(self):
129127
return f'OfType({self.type!r})'

0 commit comments

Comments
 (0)