Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ def flatten( # type: ignore[invalid-annotation]
*,
ref_index: RefMap | None = None,
ref_outer_index: RefMap | None = None,
convert_to_lojax: bool = False,
) -> tuple[GraphDef[Node], FlatState[tp.Any]]: ...
@tp.overload
def flatten( # type: ignore[invalid-annotation]
Expand All @@ -632,6 +633,7 @@ def flatten( # type: ignore[invalid-annotation]
with_paths: tp.Literal[True],
ref_index: RefMap | None = None,
ref_outer_index: RefMap | None = None,
convert_to_lojax: bool = False,
) -> tuple[
GraphDef[Node],
FlatState[tp.Any],
Expand All @@ -644,6 +646,7 @@ def flatten( # type: ignore[invalid-annotation]
with_paths: tp.Literal[False],
ref_index: RefMap | None = None,
ref_outer_index: RefMap | None = None,
convert_to_lojax: bool = False,
) -> tuple[
GraphDef[Node],
list[tp.Any],
Expand All @@ -656,6 +659,7 @@ def flatten( # type: ignore[invalid-annotation]
with_paths: bool,
ref_index: RefMap | None = None,
ref_outer_index: RefMap | None = None,
convert_to_lojax: bool = False,
) -> tuple[
GraphDef[Node],
FlatState[tp.Any] | list[tp.Any],
Expand All @@ -667,6 +671,7 @@ def flatten( # type: ignore[invalid-annotation]
with_paths: bool = True,
ref_index: RefMap | None = None,
ref_outer_index: RefMap | None = None,
convert_to_lojax: bool = False,
) -> tuple[
GraphDef[Node],
FlatState[tp.Any] | list[tp.Any],
Expand Down Expand Up @@ -700,6 +705,7 @@ def flatten( # type: ignore[invalid-annotation]
attributes,
leaves,
paths,
convert_to_lojax,
)
graphdef: GraphDef = GraphDef(
nodes=nodes, attributes=attributes, num_leaves=len(leaves)
Expand All @@ -721,6 +727,7 @@ def _graph_flatten(
attributes: list[tuple[Key, AttrType]],
leaves: list[tp.Any],
paths: list[PathParts] | None,
convert_to_lojax: bool,
) -> None:
is_pytree_node_ = type(node_impl) is PytreeNodeImpl

Expand Down Expand Up @@ -777,6 +784,8 @@ def make_mutable_arraydef(value: variablelib.Ref):
leaf = node # type: ignore[assignment]
if inner_value is not prev_inner_value:
leaf.set_raw_value(inner_value)
if convert_to_lojax and leaf.is_hijax:
leaf = variablelib._get_hijax_state(leaf)

variabledef = VariableDef(
type=node.var_type, # type: ignore
Expand Down Expand Up @@ -842,6 +851,7 @@ def make_mutable_arraydef(value: variablelib.Ref):
attributes,
leaves,
paths,
convert_to_lojax,
)
elif variablelib.is_array_ref(value):
attributes.append((key, MUTABLE_ARRAY_ATTR))
Expand Down Expand Up @@ -1092,6 +1102,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref: IndexMap | None = None,
outer_index_outer_ref: IndexMap | None = None,
copy_variables: bool = False,
convert_to_hijax: bool = False,
) -> Node:
"""Unflattens a graphdef into a node with the given state.

Expand Down Expand Up @@ -1150,6 +1161,7 @@ def unflatten( # type: ignore[invalid-annotation]
index_ref,
outer_index_outer_ref,
copy_variables,
convert_to_hijax,
)

try:
Expand All @@ -1171,6 +1183,7 @@ def _graph_unflatten(
index_ref: IndexMap,
outer_index_outer_ref: IndexMap | None,
copy_variables: bool,
convert_to_hijax: bool,
) -> Node:
"""Recursive helper for graph_unflatten.

Expand Down Expand Up @@ -1271,6 +1284,8 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
variable = variabledef.type.from_metadata(
value, dict(variabledef.metadata)
)
if convert_to_hijax and variable.is_hijax:
variable = variablelib._new_hijax_from_variable(variable)
index_ref[variabledef.index] = variable
return variable # type: ignore[return-value]

Expand Down Expand Up @@ -1326,6 +1341,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
index_ref,
outer_index_outer_ref,
copy_variables,
convert_to_hijax,
)
children.append((key, subnode))
else:
Expand Down Expand Up @@ -1696,7 +1712,10 @@ def split(
ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None
)
graphdef, flat_state = flatten(
node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index
node,
ref_index=self.ref_index,
ref_outer_index=inner_ref_outer_index,
convert_to_lojax=True,
)
flat_states = _split_state(flat_state, filters)
states = _to_nested_state(graphdef, flat_states)
Expand Down Expand Up @@ -1772,6 +1791,7 @@ def flatten( # type: ignore[invalid-annotation]
ref_index=self.ref_index,
ref_outer_index=ref_outer_index,
with_paths=with_paths,
convert_to_lojax=True,
)
if with_paths:
assert isinstance(flat_state, FlatState)
Expand Down Expand Up @@ -1801,6 +1821,7 @@ def flatten( # type: ignore[invalid-annotation]
ref_index=self.ref_index,
ref_outer_index=ref_outer_index,
with_paths=with_paths,
convert_to_lojax=True,
)
if with_paths:
assert isinstance(flat_state, FlatState)
Expand Down Expand Up @@ -1864,6 +1885,7 @@ def merge( # type: ignore[invalid-annotation]
index_ref=self.index_ref,
outer_index_outer_ref=outer_index_outer_ref,
copy_variables=True,
convert_to_hijax=True,
)
return node

Expand Down Expand Up @@ -1896,6 +1918,7 @@ def unflatten( # type: ignore[invalid-annotation]
graphdef,
state,
index_ref=self.index_ref,
convert_to_hijax=True,
)

elif static_cache is not None:
Expand Down Expand Up @@ -1938,13 +1961,15 @@ def unflatten( # type: ignore[invalid-annotation]
state,
index_ref=self.index_ref,
outer_index_outer_ref=outer_index_outer_ref,
convert_to_hijax=True,
)
else: # graphdef.outer_index is None
# its a new node, create it
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
convert_to_hijax=True,
)
else:
outer_index_outer_ref = (
Expand All @@ -1955,6 +1980,7 @@ def unflatten( # type: ignore[invalid-annotation]
state,
index_ref=self.index_ref,
outer_index_outer_ref=outer_index_outer_ref,
convert_to_hijax=True,
)
return node

Expand Down
12 changes: 12 additions & 0 deletions tests/nnx/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,18 @@ def f(v):

self.assertEqual(y.shape, ())

@nnx.use_hijax(True)
def test_nnx_jit(self):
v = nnx.Param(jnp.array([1, 2, 3]))

@nnx.vmap(in_axes=(0,))
def f(v):
v[...] += 1

f(v)

self.assertEqual(v[...], 1)


if __name__ == '__main__':
absltest.main()