Skip to content

Commit f5fbdf2

Browse files
committed
fix List and improve Pytree
1 parent 74985b2 commit f5fbdf2

File tree

3 files changed

+78
-52
lines changed

3 files changed

+78
-52
lines changed

flax/nnx/helpers.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,20 @@ def __init__(self, it: tp.Iterable[A] | None = None, /):
116116
for value in it:
117117
self.append(value)
118118

119-
def _getattr(self, key) -> A:
120-
return vars(self)[key] # type: ignore[unsupported-operands]
119+
def _get_elem(self, key: int) -> A:
120+
return getattr(self, str(key))
121121

122-
def _delattr(self, key) -> None:
123-
vars(self).pop(key)
122+
def _set_elem(self, key: int, value: A) -> None:
123+
setattr(self, str(key), value)
124+
125+
def _del_elem(self, key: int) -> None:
126+
delattr(self, str(key))
124127

125128
def __len__(self) -> int:
126129
return self._length
127130

128131
def append(self, value: A) -> None:
129-
self._setattr(self._length, value)
132+
self._set_elem(self._length, value)
130133
self._length += 1
131134

132135
def insert(self, index: int, value: A) -> None:
@@ -139,15 +142,15 @@ def insert(self, index: int, value: A) -> None:
139142

140143
# Shift elements to the right
141144
for i in range(self._length, index, -1):
142-
self._setattr(i, self._getattr(i - 1))
145+
self._set_elem(i, self._get_elem(i - 1))
143146

144147
# Insert the new value
145-
self._setattr(index, value)
148+
self._set_elem(index, value)
146149
self._length += 1
147150

148151
def __iter__(self) -> tp.Iterator[A]:
149152
for i in range(self._length):
150-
yield self._getattr(i)
153+
yield self._get_elem(i)
151154

152155
@tp.overload
153156
def __getitem__(self, index: int) -> A: ...
@@ -159,10 +162,10 @@ def __getitem__(self, index: int | slice) -> A | tp.List[A]:
159162
index += self._length
160163
if index < 0 or index >= self._length:
161164
raise IndexError('Index out of bounds')
162-
return self._getattr(index)
165+
return self._get_elem(index)
163166
elif isinstance(index, slice):
164167
idxs = list(range(self._length))[index]
165-
return [self._getattr(i) for i in idxs]
168+
return [self._get_elem(i) for i in idxs]
166169
else:
167170
raise TypeError('Invalid index type')
168171

@@ -172,7 +175,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
172175
index += self._length
173176
if index < 0 or index >= self._length:
174177
raise IndexError('Index out of bounds')
175-
self._setattr(index, value)
178+
self._set_elem(index, value)
176179
elif isinstance(index, slice):
177180
if not isinstance(value, tp.Iterable):
178181
raise TypeError('Expected an iterable')
@@ -181,7 +184,7 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
181184
if len(idxs) != len(values):
182185
raise ValueError('Length mismatch')
183186
for i, v in zip(idxs, values):
184-
self._setattr(i, v)
187+
self._set_elem(i, v)
185188
else:
186189
raise TypeError('Invalid index type')
187190

@@ -191,9 +194,9 @@ def __delitem__(self, index: int | slice) -> None:
191194
index += self._length
192195
if index < 0 or index >= self._length:
193196
raise IndexError('Index out of bounds')
194-
self._delattr(index)
197+
self._del_elem(index)
195198
for i in range(index + 1, self._length):
196-
self._setattr(i - 1, self._getattr(i))
199+
self._set_elem(i - 1, self._get_elem(i))
197200
self._length -= 1
198201
elif isinstance(index, slice):
199202
idxs = list(range(self._length))[index]
@@ -213,6 +216,9 @@ def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
213216
else:
214217
raise ValueError(f'Unsupported key type: {type(key)!r}')
215218

219+
_pytree__has_int_keys = True
220+
221+
216222
class Sequential(Module):
217223
"""A Module that applies a sequence of callables.
218224

flax/nnx/pytreelib.py

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -340,16 +340,18 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
340340
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
341341
node = cls.__new__(cls, *args, **kwargs)
342342
vars_obj = vars(node)
343-
vars_obj['_pytree__state'] = PytreeState()
344-
vars_obj['_pytree__nodes'] = cls._pytree__nodes
343+
object.__setattr__(node, '_pytree__state', PytreeState())
344+
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
345345
cls._pytree_meta_construct(node, *args, **kwargs)
346346
if cls._pytree__is_pytree:
347347
missing: dict[str, bool] = {}
348348
for name, value in vars(node).items():
349-
if name not in vars_obj['_pytree__nodes']:
349+
if name not in node._pytree__nodes:
350350
missing[name] = is_data(value)
351351
if missing:
352-
vars_obj['_pytree__nodes'] = vars_obj['_pytree__nodes'].update(missing)
352+
object.__setattr__(
353+
node, '_pytree__nodes', node._pytree__nodes.update(missing)
354+
)
353355
check_pytree(node)
354356

355357
return node
@@ -500,11 +502,10 @@ def _setattr(self, name, value: tp.Any) -> None:
500502
if name not in self._pytree__nodes or (
501503
explicit and self._pytree__nodes[name] != data
502504
):
503-
vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data})
504-
if isinstance(name, str):
505-
object.__setattr__(self, name, value)
506-
else:
507-
vars(self)[name] = value
505+
object.__setattr__(
506+
self, '_pytree__nodes', self._pytree__nodes.update({name: data})
507+
)
508+
object.__setattr__(self, name, value)
508509

509510
def _check_value(self, key, value, new_status: AttributeStatus | None):
510511
def _has_arrays(leaves):
@@ -739,20 +740,24 @@ def __getstate__(self):
739740
return vars(self).copy()
740741

741742
def __setstate__(self, state):
742-
vars(self).update(state)
743+
for key, value in state.items():
744+
object.__setattr__(self, key, value)
743745

744746
# -------------------------
745747
# Pytree Definition
746748
# -------------------------
747749
_pytree__key_sort_fn: tp.Callable | None = None
750+
_pytree__has_int_keys: bool = False
748751

749752
def _pytree__flatten_with_paths(self):
750-
obj_vars = vars(self)
753+
obj_items = vars(self).items()
754+
if self._pytree__has_int_keys:
755+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
751756
node_attributes = self._pytree__nodes
752757
node_names: list[str] = []
753758
node_attrs: list[tuple[tp.Any, tp.Any]] = []
754759
static_attrs: list[tuple[str, tp.Any]] = []
755-
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
760+
for name, value in sorted(obj_items, key=self._pytree__key_sort_fn):
756761
if name in node_attributes and node_attributes[name]:
757762
node_names.append(name)
758763
node_attrs.append((
@@ -767,12 +772,14 @@ def _pytree__flatten_with_paths(self):
767772
return node_attrs, (tuple(node_names), tuple(static_attrs))
768773

769774
def _pytree__flatten(self):
770-
obj_vars = vars(self)
775+
obj_items = vars(self).items()
776+
if self._pytree__has_int_keys:
777+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
771778
node_attributes = self._pytree__nodes
772779
node_names: list[str] = []
773780
node_attrs: list[tp.Any] = []
774781
static_attrs: list[tuple[str, tp.Any]] = []
775-
for name, value in sorted(obj_vars.items(), key=self._pytree__key_sort_fn):
782+
for name, value in sorted(obj_items, key=self._pytree__key_sort_fn):
776783
if name in node_attributes and node_attributes[name]:
777784
node_names.append(name)
778785
node_attrs.append(value)
@@ -790,45 +797,55 @@ def _pytree__unflatten(
790797
node_names, static_attrs = static
791798
obj = object.__new__(cls)
792799
vars_obj = vars(obj)
793-
vars_obj.update(zip(node_names, node_attrs, strict=True))
794-
vars_obj.update(static_attrs)
800+
if cls._pytree__has_int_keys:
801+
node_names = [
802+
str(name) if isinstance(name, int) else name for name in node_names
803+
]
804+
for name, value in zip(node_names, node_attrs, strict=True):
805+
object.__setattr__(obj, name, value)
806+
for name, value in static_attrs:
807+
object.__setattr__(obj, name, value)
795808
return obj
796809

797810
# -------------------------
798811
# Graph Definition
799812
# -------------------------
800813
def _graph_node_flatten(self):
801-
nodes = vars(self)
802-
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
814+
obj_items = vars(self).items()
815+
if self._pytree__has_int_keys:
816+
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
817+
nodes = sorted(obj_items, key=self._pytree__key_sort_fn)
803818
return nodes, type(self)
804819

805-
def _graph_node_set_key(self, key: str, value: tp.Any):
806-
if not isinstance(key, str):
807-
raise KeyError(f'Invalid key: {key!r}')
808-
elif (
809-
hasattr(self, key)
810-
and isinstance(variable := getattr(self, key), Variable)
811-
and isinstance(value, Variable)
812-
):
813-
variable.update_from_state(value)
814-
else:
815-
setattr(self, key, value)
820+
def _graph_node_set_key(self, key, value: tp.Any):
821+
if self._pytree__has_int_keys and isinstance(key, int):
822+
key = str(key)
823+
setattr(self, key, value)
816824

817-
def _graph_node_pop_key(self, key: str):
818-
if not isinstance(key, str):
819-
raise KeyError(f'Invalid key: {key!r}')
820-
return vars(self).pop(key)
825+
def _graph_node_pop_key(self, key):
826+
if self._pytree__has_int_keys and isinstance(key, int):
827+
key = str(key)
828+
value = getattr(self, key)
829+
delattr(self, key)
830+
return value
821831

822832
@staticmethod
823833
def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
824834
node = object.__new__(node_type)
825835
return node
826836

827837
def _graph_node_clear(self):
828-
vars(self).clear()
838+
for name in list(vars(self)):
839+
delattr(self, name)
829840

830841
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
831-
vars(self).update(attributes)
842+
if self._pytree__has_int_keys:
843+
attributes = (
844+
(str(name) if isinstance(name, int) else name, value)
845+
for name, value in attributes
846+
)
847+
for name, value in attributes:
848+
object.__setattr__(self, name, value)
832849

833850
if tp.TYPE_CHECKING:
834851
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
@@ -845,3 +862,9 @@ def __init_subclass__(cls, **kwargs):
845862
f'{pytree!r} for type {cls}.'
846863
)
847864
super().__init_subclass__(pytree=pytree, **kwargs)
865+
866+
def _maybe_int(x):
867+
try:
868+
return int(x)
869+
except (ValueError, TypeError):
870+
return x

tests/nnx/partitioning_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ def test_get_paritition(self):
150150
d=5.0,
151151
)
152152

153-
# test Variables not shared
154-
self.assertIsNot(vars(m.a)[0], vars(m)['b'])
155-
156153
state = nnx.state(m, nnx.Variable)
157154
self.assertEqual(state['a'][0][...], m.a[0][...])
158155
self.assertEqual(state['a'][1][...], m.a[1][...])

0 commit comments

Comments
 (0)