Skip to content

Commit 777a8e1

Browse files
committed
Variable.__setattr__ no longer sets metadata
1 parent a9dccc0 commit 777a8e1

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

flax/nnx/spmd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def _add_axis(x: tp.Any):
4747
metadata = x.get_metadata()
4848
if 'sharding_names' in metadata and metadata['sharding_names']:
4949
sharding = metadata['sharding_names']
50-
x.sharding_names = insert_field(sharding, index, axis_name)
50+
x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
5151

5252
for k, v in other_meta.items():
5353
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
54-
setattr(x, k, insert_field(t, index, v))
54+
x.set_metadata(k, insert_field(t, index, v))
5555

5656
assert isinstance(x, variablelib.Variable)
5757
x.add_axis(index, axis_name)
@@ -75,11 +75,13 @@ def remove_field(fields, index, value):
7575
def _remove_axis(x: tp.Any):
7676
if isinstance(x, variablelib.Variable):
7777
if hasattr(x, 'sharding_names') and x.sharding_names is not None:
78-
x.sharding_names = remove_field(x.sharding_names, index, axis_name)
78+
x.set_metadata(
79+
sharding_names=remove_field(x.sharding_names, index, axis_name)
80+
)
7981

8082
for k, v in other_meta.items():
8183
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
82-
setattr(x, k, remove_field(t, index, v))
84+
x.set_metadata(k, remove_field(t, index, v))
8385

8486
x.remove_axis(index, axis_name)
8587
return x

flax/nnx/variablelib.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,14 @@ class VariableMetadata(tp.Generic[A]):
167167
metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict)
168168

169169

170-
class Variable(tp.Generic[A], reprlib.Representable):
170+
class VariableMeta(type):
171+
def __new__(cls, cls_name, bases, attrs):
172+
if '__slots__' not in attrs:
173+
attrs['__slots__'] = ()
174+
return super().__new__(cls, cls_name, bases, attrs)
175+
176+
177+
class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta):
171178
"""The base class for all ``Variable`` types. Create custom ``Variable``
172179
types by subclassing this class. Numerous NNX graph functions can filter
173180
for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
@@ -303,15 +310,13 @@ def __setattr__(self, name: str, value: tp.Any):
303310
raise errors.TraceContextError(
304311
f'Cannot mutate {type(self).__name__} from a different trace level'
305312
)
306-
if (
307-
name == 'value'
308-
or name == 'raw_value'
309-
or name == '_var_metadata'
310-
or name == '_trace_state'
311-
):
313+
try:
312314
object.__setattr__(self, name, value)
313-
else:
314-
self._var_metadata[name] = value
315+
except AttributeError:
316+
raise AttributeError(
317+
f'Cannot set attribute {name}.\n'
318+
f"To set Variable metadata use: `variable.set_metadata('{name}', value)`."
319+
)
315320

316321
def __delattr__(self, name: str):
317322
if not self._trace_state.is_valid():
@@ -363,32 +368,40 @@ def get_metadata(self, name: str | None = None):
363368
@tp.overload
364369
def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ...
365370
@tp.overload
371+
def set_metadata(self, name: str, value: tp.Any, /) -> None: ...
372+
@tp.overload
366373
def set_metadata(self, **metadata: tp.Any) -> None: ...
367374
def set_metadata(self, *args, **kwargs) -> None:
368375
"""Set metadata for the Variable.
369376
370-
`set_metadata` can be called in two ways:
377+
`set_metadata` can be called in 3 ways:
371378
372379
1. By passing a dictionary of metadata as the first argument, this will replace
373380
the entire Variable's metadata.
374-
2. By using keyword arguments, these will be merged into the existing Variable's
375-
metadata.
381+
2. By passing a name and value as the first two arguments, this will set
382+
the metadata entry for the given name to the given value.
383+
3. By using keyword arguments, this will update the Variable's metadata
384+
with the provided key-value pairs.
376385
"""
377386
if not self._trace_state.is_valid():
378387
raise errors.TraceContextError(
379388
f'Cannot mutate {type(self).__name__} from a different trace level'
380389
)
381-
if not (bool(args) ^ bool(kwargs)):
390+
if args and kwargs:
382391
raise TypeError(
383-
'set_metadata takes either a single dict argument or keyword arguments'
392+
'Cannot mix positional and keyword arguments in set_metadata'
384393
)
385394
if len(args) == 1:
386-
self._var_metadata = args[0]
395+
self._var_metadata = dict(args[0])
396+
elif len(args) == 2:
397+
name, value = args
398+
self._var_metadata[name] = value
387399
elif kwargs:
388400
self._var_metadata.update(kwargs)
389401
else:
390402
raise TypeError(
391-
f'set_metadata takes either 1 argument or 1 or more keyword arguments, got args={args}, kwargs={kwargs}'
403+
f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, '
404+
f'got args={args}, kwargs={kwargs}'
392405
)
393406

394407
def copy_from(self, other: Variable[A]) -> None:

0 commit comments

Comments
 (0)