Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import numpy as np
import typing_extensions as tpe

from flax.core.frozen_dict import FrozenDict
from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
ApplyCaller,
Expand Down Expand Up @@ -183,7 +182,7 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
return _node_impl_for_type[x]


class _HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)

Expand All @@ -204,7 +203,7 @@ def __hash__(self) -> int:

def __eq__(self, other: tp.Any) -> bool:
return (
isinstance(other, _HashableMapping) and self._mapping == other._mapping
isinstance(other, HashableMapping) and self._mapping == other._mapping
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -246,7 +245,7 @@ def __treescope_repr__(self, path, subtree_renderer):
class VariableDef(reprlib.Representable):
type: type[Variable]
index: int
metadata: FrozenDict[str, tp.Any]
metadata: HashableMapping[str, tp.Any]

def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
Expand All @@ -272,7 +271,7 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, repr=False)
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
class NodeDef(GraphDef[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
:class:`Module`. A ``GraphDef`` can be generated by either
Expand All @@ -281,11 +280,11 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
type: tp.Type[Node]
index: int
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: _HashableMapping[Key, tp.Any]
leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: HashableMapping[Key, tp.Any]
leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
metadata: tp.Any
index_mapping: FrozenDict[Index, Index] | None
index_mapping: HashableMapping[Index, Index] | None

@classmethod
def create(
Expand All @@ -303,11 +302,11 @@ def create(
type=type,
index=index,
attributes=attributes,
subgraphs=_HashableMapping(subgraphs),
static_fields=_HashableMapping(static_fields),
leaves=_HashableMapping(leaves),
subgraphs=HashableMapping(subgraphs),
static_fields=HashableMapping(static_fields),
leaves=HashableMapping(leaves),
metadata=metadata,
index_mapping=FrozenDict(index_mapping)
index_mapping=HashableMapping(index_mapping)
if index_mapping is not None
else None,
)
Expand Down Expand Up @@ -424,7 +423,7 @@ def _graph_flatten(
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, FrozenDict(value.get_metadata())
type(value), variable_index, HashableMapping(value.get_metadata())
)
leaves.append((key, variabledef))
else:
Expand Down Expand Up @@ -794,7 +793,7 @@ def split(
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=FrozenDict(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index)
)

return graphdef, *states
Expand Down Expand Up @@ -984,7 +983,7 @@ def split(
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=FrozenDict(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index)
)

self.flatten_end(ref_index)
Expand Down
15 changes: 8 additions & 7 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import (
extract,
filterlib,
Expand Down Expand Up @@ -428,7 +427,7 @@ def _custom_vjp_split_fn(
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)

def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]):
if isinstance(x, graph.NodeDef):
assert x.index_mapping is not None
index_mappings.append(x.index_mapping)
Expand Down Expand Up @@ -466,7 +465,9 @@ def __call__(self, *pure_args):
(args_out, out), ctxtag=self.ctxtag
)
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)
index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state(
self.ctxtag
)

pure_args_out, pure_out = jax.tree.map(
functools.partial(_extract_index_mappings, index_mappings=index_mappings),
Expand Down Expand Up @@ -519,8 +520,8 @@ def __call__(self, *pure_args):

if update_context_active:
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
self.ctxtag
index_mappings: deque[graph.HashableMapping] = (
extract.get_broadcast_state(self.ctxtag)
)
pure_args_out, pure_out = jax.tree.map(
functools.partial(
Expand Down Expand Up @@ -631,7 +632,7 @@ def __call__(
for i, x in enumerate(tree_node_args)
if i not in self.jax_nondiff_argnums
)
index_mappings: deque[FrozenDict] = deque()
index_mappings: deque[graph.HashableMapping] = deque()
with extract.broadcast_state(self.ctxtag, index_mappings):
if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
raise ValueError()
Expand Down Expand Up @@ -663,7 +664,7 @@ def __call__(
# insert index_mappings
def _insert_index_mappings(x):
if isinstance(x, graph.NodeDef):
index_mapping: FrozenDict = index_mappings.popleft()
index_mapping: graph.HashableMapping = index_mappings.popleft()
return dataclasses.replace(x, index_mapping=index_mapping)
return x

Expand Down
16 changes: 9 additions & 7 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def check_carry_same_references(key_path, arg, out):

def _extract_index_mappings(
pure_carry_arg_out,
carry_index_mappings: list[FrozenDict[int, int]],
carry_index_mappings: list[graph.HashableMapping[int, int]],
/,
):
def extract_index_mappings(x):
Expand All @@ -675,7 +675,7 @@ def extract_index_mappings(x):

def _insert_index_mappings(
pure_carry_arg_out,
carry_index_mappings: deque[FrozenDict[int, int]],
carry_index_mappings: deque[graph.HashableMapping[int, int]],
/,
):
def insert_index_mappings(x):
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def __call__(

# next we have to remove all the index_mappings from the NodeDefs
# in the carry outputs because they are not present in the inputs
carry_index_mappings: list[FrozenDict[int, int]] = []
carry_index_mappings: list[graph.HashableMapping[int, int]] = []
pure_carry_arg_out = _extract_index_mappings(
pure_carry_arg_out, carry_index_mappings
)
Expand Down Expand Up @@ -1347,10 +1347,12 @@ def per_node_def(nd: graph.NodeDef | tp.Any):
return

per_node_def(ns._graphdef)
return dataclasses.replace(ns, _graphdef=dataclasses.replace(
ns._graphdef,
index_mapping=FrozenDict(global_index_mapping)
))
return dataclasses.replace(
ns,
_graphdef=dataclasses.replace(
ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping)
),
)

return jax.tree.map(per_node_state, tree,
is_leaf=lambda x: isinstance(x, extract.NodeStates))
Expand Down
14 changes: 10 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading