Skip to content

Commit a49e06b

Browse files
committed
fix List and improve Pytree
1 parent cd37bc9 commit a49e06b

File tree

4 files changed

+96
-72
lines changed

4 files changed

+96
-72
lines changed

flax/nnx/graph.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,15 @@ def get_node_impl_for_type(
296296
else:
297297
return None
298298

299+
# use type-aware sorting to support int keys
300+
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
301+
key, _ = item
302+
if isinstance(key, int):
303+
return (0, key)
304+
elif isinstance(key, str):
305+
return (1, key)
306+
else:
307+
raise ValueError(f'Unsupported key type: {type(key)!r}')
299308

300309
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
301310
_mapping: dict[HA, HB] | tp.Mapping[HA, HB]
@@ -316,16 +325,7 @@ def __len__(self) -> int:
316325
return len(self._mapping)
317326

318327
def __hash__(self) -> int:
319-
# use type-aware sorting to support int keys
320-
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
321-
key, _ = item
322-
if isinstance(key, int):
323-
return (0, key)
324-
elif isinstance(key, str):
325-
return (1, key)
326-
else:
327-
raise ValueError(f'Unsupported key type: {type(key)!r}')
328-
return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn)))
328+
return hash(tuple(sorted(self._mapping.items(), key=_type_aware_sort)))
329329

330330
def __eq__(self, other: tp.Any) -> bool:
331331
return (

flax/nnx/helpers.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,20 @@ def __init__(self, it: tp.Iterable[A] | None = None, /):
116116
for value in it:
117117
self.append(value)
118118

119-
def _getattr(self, key) -> A:
120-
return vars(self)[key] # type: ignore[unsupported-operands]
119+
def _get_elem(self, key: int) -> A:
120+
return getattr(self, str(key))
121121

122-
def _delattr(self, key) -> None:
123-
vars(self).pop(key)
122+
def _set_elem(self, key: int, value: A) -> None:
123+
setattr(self, str(key), value)
124+
125+
def _del_elem(self, key: int) -> None:
126+
delattr(self, str(key))
124127

125128
def __len__(self) -> int:
126129
return self._length
127130

128131
def append(self, value: A) -> None:
129-
self._setattr(self._length, value)
132+
self._set_elem(self._length, value)
130133
self._length += 1
131134

132135
def insert(self, index: int, value: A) -> None:
@@ -139,15 +142,15 @@ def insert(self, index: int, value: A) -> None:
139142

140143
# Shift elements to the right
141144
for i in range(self._length, index, -1):
142-
self._setattr(i, self._getattr(i - 1))
145+
self._set_elem(i, self._get_elem(i - 1))
143146

144147
# Insert the new value
145-
self._setattr(index, value)
148+
self._set_elem(index, value)
146149
self._length += 1
147150

148151
def __iter__(self) -> tp.Iterator[A]:
149152
for i in range(self._length):
150-
yield self._getattr(i)
153+
yield self._get_elem(i)
151154

152155
@tp.overload
153156
def __getitem__(self, index: int) -> A: ...
@@ -159,10 +162,10 @@ def __getitem__(self, index: int | slice) -> A | tp.List[A]:
159162
index += self._length
160163
if index < 0 or index >= self._length:
161164
raise IndexError('Index out of bounds')
162-
return self._getattr(index)
165+
return self._get_elem(index)
163166
elif isinstance(index, slice):
164167
idxs = list(range(self._length))[index]
165-
return [self._getattr(i) for i in idxs]
168+
return [self._get_elem(i) for i in idxs]
166169
else:
167170
raise TypeError('Invalid index type')
168171

@@ -172,7 +175,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
172175
index += self._length
173176
if index < 0 or index >= self._length:
174177
raise IndexError('Index out of bounds')
175-
self._setattr(index, value)
178+
self._set_elem(index, value)
176179
elif isinstance(index, slice):
177180
if not isinstance(value, tp.Iterable):
178181
raise TypeError('Expected an iterable')
@@ -181,7 +184,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
181184
if len(idxs) != len(values):
182185
raise ValueError('Length mismatch')
183186
for i, v in zip(idxs, values):
184-
self._setattr(i, v)
187+
self._set_elem(i, v)
185188
else:
186189
raise TypeError('Invalid index type')
187190

@@ -191,9 +194,9 @@ def __delitem__(self, index: int | slice) -> None:
191194
index += self._length
192195
if index < 0 or index >= self._length:
193196
raise IndexError('Index out of bounds')
194-
self._delattr(index)
197+
self._del_elem(index)
195198
for i in range(index + 1, self._length):
196-
self._setattr(i - 1, self._getattr(i))
199+
self._set_elem(i - 1, self._get_elem(i))
197200
self._length -= 1
198201
elif isinstance(index, slice):
199202
idxs = list(range(self._length))[index]
@@ -203,15 +206,8 @@ def __delitem__(self, index: int | slice) -> None:
203206
else:
204207
raise TypeError('Invalid index type')
205208

206-
@staticmethod
207-
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
208-
key, _ = item
209-
if isinstance(key, int):
210-
return (0, key)
211-
elif isinstance(key, str):
212-
return (1, key)
213-
else:
214-
raise ValueError(f'Unsupported key type: {type(key)!r}')
209+
_pytree__has_int_keys = True
210+
215211

216212
class Sequential(Module):
217213
"""A Module that applies a sequence of callables.

flax/nnx/pytreelib.py

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -340,16 +340,18 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
340340
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
341341
node = cls.__new__(cls, *args, **kwargs)
342342
vars_obj = vars(node)
343-
vars_obj['_pytree__state'] = PytreeState()
344-
vars_obj['_pytree__nodes'] = cls._pytree__nodes
343+
object.__setattr__(node, '_pytree__state', PytreeState())
344+
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
345345
cls._pytree_meta_construct(node, *args, **kwargs)
346346
if cls._pytree__is_pytree:
347347
missing: dict[str, bool] = {}
348348
for name, value in vars(node).items():
349-
if name not in vars_obj['_pytree__nodes']:
349+
if name not in node._pytree__nodes:
350350
missing[name] = is_data(value)
351351
if missing:
352-
vars_obj['_pytree__nodes'] = vars_obj['_pytree__nodes'].update(missing)
352+
object.__setattr__(
353+
node, '_pytree__nodes', node._pytree__nodes.update(missing)
354+
)
353355
check_pytree(node)
354356

355357
return node
@@ -500,11 +502,10 @@ def _setattr(self, name, value: tp.Any) -> None:
500502
if name not in self._pytree__nodes or (
501503
explicit and self._pytree__nodes[name] != data
502504
):
503-
vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data})
504-
if isinstance(name, str):
505-
object.__setattr__(self, name, value)
506-
else:
507-
vars(self)[name] = value
505+
object.__setattr__(
506+
self, '_pytree__nodes', self._pytree__nodes.update({name: data})
507+
)
508+
object.__setattr__(self, name, value)
508509

509510
def _check_value(self, key, value, new_status: AttributeStatus | None):
510511
def _has_arrays(leaves):
@@ -739,20 +740,26 @@ def __getstate__(self):
739740
return vars(self).copy()
740741

741742
def __setstate__(self, state):
742-
vars(self).update(state)
743+
for key, value in state.items():
744+
object.__setattr__(self, key, value)
743745

744746
# -------------------------
745747
# Pytree Definition
746748
# -------------------------
747-
_pytree__key_sort_fn: tp.Callable | None = None
749+
_pytree__has_int_keys: bool = False
748750

749751
def _pytree__flatten_with_paths(self):
750-
obj_vars = vars(self)
752+
obj_items = vars(self).items()
753+
if self._pytree__has_int_keys:
754+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
755+
key_fn = graph._type_aware_sort
756+
else:
757+
key_fn = None
751758
node_attributes = self._pytree__nodes
752759
node_names: list[str] = []
753760
node_attrs: list[tuple[tp.Any, tp.Any]] = []
754761
static_attrs: list[tuple[str, tp.Any]] = []
755-
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
762+
for name, value in sorted(obj_items, key=key_fn):
756763
if name in node_attributes and node_attributes[name]:
757764
node_names.append(name)
758765
node_attrs.append((
@@ -767,12 +774,17 @@ def _pytree__flatten_with_paths(self):
767774
return node_attrs, (tuple(node_names), tuple(static_attrs))
768775

769776
def _pytree__flatten(self):
770-
obj_vars = vars(self)
777+
obj_items = vars(self).items()
778+
if self._pytree__has_int_keys:
779+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
780+
key_fn = graph._type_aware_sort
781+
else:
782+
key_fn = None
771783
node_attributes = self._pytree__nodes
772784
node_names: list[str] = []
773785
node_attrs: list[tp.Any] = []
774786
static_attrs: list[tuple[str, tp.Any]] = []
775-
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
787+
for name, value in sorted(obj_items, key=key_fn):
776788
if name in node_attributes and node_attributes[name]:
777789
node_names.append(name)
778790
node_attrs.append(value)
@@ -790,45 +802,58 @@ def _pytree__unflatten(
790802
node_names, static_attrs = static
791803
obj = object.__new__(cls)
792804
vars_obj = vars(obj)
793-
vars_obj.update(zip(node_names, node_attrs, strict=True))
794-
vars_obj.update(static_attrs)
805+
if cls._pytree__has_int_keys:
806+
node_names = [
807+
str(name) if isinstance(name, int) else name for name in node_names
808+
]
809+
for name, value in zip(node_names, node_attrs, strict=True):
810+
object.__setattr__(obj, name, value)
811+
for name, value in static_attrs:
812+
object.__setattr__(obj, name, value)
795813
return obj
796814

797815
# -------------------------
798816
# Graph Definition
799817
# -------------------------
800818
def _graph_node_flatten(self):
801-
nodes = vars(self)
802-
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
819+
obj_items = vars(self).items()
820+
if self._pytree__has_int_keys:
821+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
822+
key_fn = graph._type_aware_sort
823+
else:
824+
key_fn = None
825+
nodes = sorted(obj_items, key=key_fn)
803826
return nodes, type(self)
804827

805-
def _graph_node_set_key(self, key: str, value: tp.Any):
806-
if not isinstance(key, str):
807-
raise KeyError(f'Invalid key: {key!r}')
808-
elif (
809-
hasattr(self, key)
810-
and isinstance(variable := getattr(self, key), Variable)
811-
and isinstance(value, Variable)
812-
):
813-
variable.update_from_state(value)
814-
else:
815-
setattr(self, key, value)
828+
def _graph_node_set_key(self, key, value: tp.Any):
829+
if self._pytree__has_int_keys and isinstance(key, int):
830+
key = str(key)
831+
setattr(self, key, value)
816832

817-
def _graph_node_pop_key(self, key: str):
818-
if not isinstance(key, str):
819-
raise KeyError(f'Invalid key: {key!r}')
820-
return vars(self).pop(key)
833+
def _graph_node_pop_key(self, key):
834+
if self._pytree__has_int_keys and isinstance(key, int):
835+
key = str(key)
836+
value = getattr(self, key)
837+
delattr(self, key)
838+
return value
821839

822840
@staticmethod
823841
def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
824842
node = object.__new__(node_type)
825843
return node
826844

827845
def _graph_node_clear(self):
828-
vars(self).clear()
846+
for name in list(vars(self)):
847+
delattr(self, name)
829848

830849
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
831-
vars(self).update(attributes)
850+
if self._pytree__has_int_keys:
851+
attributes = (
852+
(str(name) if isinstance(name, int) else name, value)
853+
for name, value in attributes
854+
)
855+
for name, value in attributes:
856+
object.__setattr__(self, name, value)
832857

833858
if tp.TYPE_CHECKING:
834859
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
@@ -845,3 +870,9 @@ def __init_subclass__(cls, **kwargs):
845870
f'{pytree!r} for type {cls}.'
846871
)
847872
super().__init_subclass__(pytree=pytree, **kwargs)
873+
874+
def _maybe_int(x):
875+
try:
876+
return int(x)
877+
except (ValueError, TypeError):
878+
return x

tests/nnx/partitioning_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ def test_get_paritition(self):
150150
d=5.0,
151151
)
152152

153-
# test Variables not shared
154-
self.assertIsNot(vars(m.a)[0], vars(m)['b'])
155-
156153
state = nnx.state(m, nnx.Variable)
157154
self.assertEqual(state['a'][0][...], m.a[0][...])
158155
self.assertEqual(state['a'][1][...], m.a[1][...])

0 commit comments

Comments
 (0)