Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ def get_node_impl_for_type(
else:
return None

# use type-aware sorting to support int keys
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')

class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
_mapping: dict[HA, HB] | tp.Mapping[HA, HB]
Expand All @@ -316,16 +325,7 @@ def __len__(self) -> int:
return len(self._mapping)

def __hash__(self) -> int:
# use type-aware sorting to support int keys
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')
return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn)))
return hash(tuple(sorted(self._mapping.items(), key=_type_aware_sort)))

def __eq__(self, other: tp.Any) -> bool:
return (
Expand Down
42 changes: 19 additions & 23 deletions flax/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,20 @@ def __init__(self, it: tp.Iterable[A] | None = None, /):
for value in it:
self.append(value)

def _getattr(self, key) -> A:
return vars(self)[key] # type: ignore[unsupported-operands]
def _get_elem(self, key: int) -> A:
return getattr(self, str(key))

def _delattr(self, key) -> None:
vars(self).pop(key)
def _set_elem(self, key: int, value: A) -> None:
setattr(self, str(key), value)

def _del_elem(self, key: int) -> None:
delattr(self, str(key))

def __len__(self) -> int:
return self._length

def append(self, value: A) -> None:
self._setattr(self._length, value)
self._set_elem(self._length, value)
self._length += 1

def insert(self, index: int, value: A) -> None:
Expand All @@ -139,15 +142,15 @@ def insert(self, index: int, value: A) -> None:

# Shift elements to the right
for i in range(self._length, index, -1):
self._setattr(i, self._getattr(i - 1))
self._set_elem(i, self._get_elem(i - 1))

# Insert the new value
self._setattr(index, value)
self._set_elem(index, value)
self._length += 1

def __iter__(self) -> tp.Iterator[A]:
for i in range(self._length):
yield self._getattr(i)
yield self._get_elem(i)

@tp.overload
def __getitem__(self, index: int) -> A: ...
Expand All @@ -159,10 +162,10 @@ def __getitem__(self, index: int | slice) -> A | tp.List[A]:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
return self._getattr(index)
return self._get_elem(index)
elif isinstance(index, slice):
idxs = list(range(self._length))[index]
return [self._getattr(i) for i in idxs]
return [self._get_elem(i) for i in idxs]
else:
raise TypeError('Invalid index type')

Expand All @@ -172,7 +175,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
self._setattr(index, value)
self._set_elem(index, value)
elif isinstance(index, slice):
if not isinstance(value, tp.Iterable):
raise TypeError('Expected an iterable')
Expand All @@ -181,7 +184,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
if len(idxs) != len(values):
raise ValueError('Length mismatch')
for i, v in zip(idxs, values):
self._setattr(i, v)
self._set_elem(i, v)
else:
raise TypeError('Invalid index type')

Expand All @@ -191,9 +194,9 @@ def __delitem__(self, index: int | slice) -> None:
index += self._length
if index < 0 or index >= self._length:
raise IndexError('Index out of bounds')
self._delattr(index)
self._del_elem(index)
for i in range(index + 1, self._length):
self._setattr(i - 1, self._getattr(i))
self._set_elem(i - 1, self._get_elem(i))
self._length -= 1
elif isinstance(index, slice):
idxs = list(range(self._length))[index]
Expand All @@ -203,15 +206,8 @@ def __delitem__(self, index: int | slice) -> None:
else:
raise TypeError('Invalid index type')

@staticmethod
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')
_pytree__has_int_keys = True


class Sequential(Module):
"""A Module that applies a sequence of callables.
Expand Down
103 changes: 67 additions & 36 deletions flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,18 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
node = cls.__new__(cls, *args, **kwargs)
vars_obj = vars(node)
vars_obj['_pytree__state'] = PytreeState()
vars_obj['_pytree__nodes'] = cls._pytree__nodes
object.__setattr__(node, '_pytree__state', PytreeState())
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
cls._pytree_meta_construct(node, *args, **kwargs)
if cls._pytree__is_pytree:
missing: dict[str, bool] = {}
for name, value in vars(node).items():
if name not in vars_obj['_pytree__nodes']:
if name not in node._pytree__nodes:
missing[name] = is_data(value)
if missing:
vars_obj['_pytree__nodes'] = vars_obj['_pytree__nodes'].update(missing)
object.__setattr__(
node, '_pytree__nodes', node._pytree__nodes.update(missing)
)
check_pytree(node)

return node
Expand Down Expand Up @@ -500,11 +502,10 @@ def _setattr(self, name, value: tp.Any) -> None:
if name not in self._pytree__nodes or (
explicit and self._pytree__nodes[name] != data
):
vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data})
if isinstance(name, str):
object.__setattr__(self, name, value)
else:
vars(self)[name] = value
object.__setattr__(
self, '_pytree__nodes', self._pytree__nodes.update({name: data})
)
object.__setattr__(self, name, value)

def _check_value(self, key, value, new_status: AttributeStatus | None):
def _has_arrays(leaves):
Expand Down Expand Up @@ -739,20 +740,26 @@ def __getstate__(self):
return vars(self).copy()

def __setstate__(self, state):
vars(self).update(state)
for key, value in state.items():
object.__setattr__(self, key, value)

# -------------------------
# Pytree Definition
# -------------------------
_pytree__key_sort_fn: tp.Callable | None = None
_pytree__has_int_keys: bool = False

def _pytree__flatten_with_paths(self):
obj_vars = vars(self)
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_attrs: list[tuple[tp.Any, tp.Any]] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
for name, value in sorted(obj_items, key=key_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
node_attrs.append((
Expand All @@ -767,12 +774,17 @@ def _pytree__flatten_with_paths(self):
return node_attrs, (tuple(node_names), tuple(static_attrs))

def _pytree__flatten(self):
obj_vars = vars(self)
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
node_attributes = self._pytree__nodes
node_names: list[str] = []
node_attrs: list[tp.Any] = []
static_attrs: list[tuple[str, tp.Any]] = []
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
for name, value in sorted(obj_items, key=key_fn):
if name in node_attributes and node_attributes[name]:
node_names.append(name)
node_attrs.append(value)
Expand All @@ -790,45 +802,58 @@ def _pytree__unflatten(
node_names, static_attrs = static
obj = object.__new__(cls)
vars_obj = vars(obj)
vars_obj.update(zip(node_names, node_attrs, strict=True))
vars_obj.update(static_attrs)
if cls._pytree__has_int_keys:
node_names = [
str(name) if isinstance(name, int) else name for name in node_names
]
for name, value in zip(node_names, node_attrs, strict=True):
object.__setattr__(obj, name, value)
for name, value in static_attrs:
object.__setattr__(obj, name, value)
return obj

# -------------------------
# Graph Definition
# -------------------------
def _graph_node_flatten(self):
nodes = vars(self)
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
obj_items = vars(self).items()
if self._pytree__has_int_keys:
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
key_fn = graph._type_aware_sort
else:
key_fn = None
nodes = sorted(obj_items, key=key_fn)
return nodes, type(self)

def _graph_node_set_key(self, key: str, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
hasattr(self, key)
and isinstance(variable := getattr(self, key), Variable)
and isinstance(value, Variable)
):
variable.update_from_state(value)
else:
setattr(self, key, value)
def _graph_node_set_key(self, key, value: tp.Any):
if self._pytree__has_int_keys and isinstance(key, int):
key = str(key)
setattr(self, key, value)

def _graph_node_pop_key(self, key: str):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
return vars(self).pop(key)
def _graph_node_pop_key(self, key):
if self._pytree__has_int_keys and isinstance(key, int):
key = str(key)
value = getattr(self, key)
delattr(self, key)
return value

@staticmethod
def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
node = object.__new__(node_type)
return node

def _graph_node_clear(self):
vars(self).clear()
for name in list(vars(self)):
delattr(self, name)

def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
vars(self).update(attributes)
if self._pytree__has_int_keys:
attributes = (
(str(name) if isinstance(name, int) else name, value)
for name, value in attributes
)
for name, value in attributes:
object.__setattr__(self, name, value)

if tp.TYPE_CHECKING:
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
Expand All @@ -845,3 +870,9 @@ def __init_subclass__(cls, **kwargs):
f'{pytree!r} for type {cls}.'
)
super().__init_subclass__(pytree=pytree, **kwargs)

def _maybe_int(x):
try:
return int(x)
except (ValueError, TypeError):
return x
3 changes: 0 additions & 3 deletions tests/nnx/partitioning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ def test_get_paritition(self):
d=5.0,
)

# test Variables not shared
self.assertIsNot(vars(m.a)[0], vars(m)['b'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't make sense here, probably copied from another test in the past


state = nnx.state(m, nnx.Variable)
self.assertEqual(state['a'][0][...], m.a[0][...])
self.assertEqual(state['a'][1][...], m.a[1][...])
Expand Down
Loading