Skip to content

Commit 6080725

Browse files
author
Flax Authors
committed
Merge pull request #4985 from google:no-setattr-metadata
PiperOrigin-RevId: 813964897
2 parents 8025b3d + c23304d commit 6080725

File tree

3 files changed

+42
-15
lines changed

3 files changed

+42
-15
lines changed

flax/nnx/spmd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def _add_axis(x: tp.Any):
4747
metadata = x.get_metadata()
4848
if 'sharding_names' in metadata and metadata['sharding_names']:
4949
sharding = metadata['sharding_names']
50-
x.sharding_names = insert_field(sharding, index, axis_name)
50+
x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
5151

5252
for k, v in other_meta.items():
5353
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
54-
setattr(x, k, insert_field(t, index, v))
54+
x.set_metadata(k, insert_field(t, index, v))
5555

5656
assert isinstance(x, variablelib.Variable)
5757
x.add_axis(index, axis_name)
@@ -75,11 +75,13 @@ def remove_field(fields, index, value):
7575
def _remove_axis(x: tp.Any):
7676
if isinstance(x, variablelib.Variable):
7777
if hasattr(x, 'sharding_names') and x.sharding_names is not None:
78-
x.sharding_names = remove_field(x.sharding_names, index, axis_name)
78+
x.set_metadata(
79+
sharding_names=remove_field(x.sharding_names, index, axis_name)
80+
)
7981

8082
for k, v in other_meta.items():
8183
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
82-
setattr(x, k, remove_field(t, index, v))
84+
x.set_metadata(k, remove_field(t, index, v))
8385

8486
x.remove_axis(index, axis_name)
8587
return x

flax/nnx/variablelib.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from flax import errors
3030
from flax.core import spmd as core_spmd
3131
from flax.nnx import filterlib, reprlib, tracers, visualization
32-
from flax.typing import Missing, PathParts, SizeBytes
32+
from flax.typing import MISSING, Missing, PathParts, SizeBytes
3333
import jax.tree_util as jtu
3434
import jax.numpy as jnp
3535
from jax._src.state.types import AbstractRef
@@ -172,7 +172,14 @@ class VariableMetadata(tp.Generic[A]):
172172
metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict)
173173

174174

175-
class Variable(tp.Generic[A], reprlib.Representable):
175+
class VariableMeta(type):
176+
def __new__(cls, cls_name, bases, attrs):
177+
if '__slots__' not in attrs:
178+
attrs['__slots__'] = ()
179+
return super().__new__(cls, cls_name, bases, attrs)
180+
181+
182+
class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta):
176183
"""The base class for all ``Variable`` types. Create custom ``Variable``
177184
types by subclassing this class. Numerous NNX graph functions can filter
178185
for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
@@ -353,47 +360,61 @@ def has_ref(self) -> bool:
353360
@tp.overload
354361
def get_metadata(self) -> dict[str, tp.Any]: ...
355362
@tp.overload
356-
def get_metadata(self, name: str) -> tp.Any: ...
357-
def get_metadata(self, name: str | None = None):
363+
def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ...
364+
def get_metadata(
365+
self, name: str | None = None, default: tp.Any = MISSING
366+
) -> tp.Any:
358367
"""Get metadata for the Variable.
359368
360369
Args:
361370
name: The key of the metadata element to get. If not provided, returns
362371
the full metadata dictionary.
372+
default: The default value to return if the metadata key is not found. If
373+
not provided and the key is not found, raises a KeyError.
363374
"""
364375
if name is None:
365376
return self._var_metadata
377+
if name not in self._var_metadata and not isinstance(default, Missing):
378+
return default
366379
return self._var_metadata[name]
367380

368381
@tp.overload
369382
def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ...
370383
@tp.overload
384+
def set_metadata(self, name: str, value: tp.Any, /) -> None: ...
385+
@tp.overload
371386
def set_metadata(self, **metadata: tp.Any) -> None: ...
372387
def set_metadata(self, *args, **kwargs) -> None:
373388
"""Set metadata for the Variable.
374389
375-
`set_metadata` can be called in two ways:
390+
`set_metadata` can be called in 3 ways:
376391
377392
1. By passing a dictionary of metadata as the first argument, this will replace
378393
the entire Variable's metadata.
379-
2. By using keyword arguments, these will be merged into the existing Variable's
380-
metadata.
394+
2. By passing a name and value as the first two arguments, this will set
395+
the metadata entry for the given name to the given value.
396+
3. By using keyword arguments, this will update the Variable's metadata
397+
with the provided key-value pairs.
381398
"""
382399
if not self._trace_state.is_valid():
383400
raise errors.TraceContextError(
384401
f'Cannot mutate {type(self).__name__} from a different trace level'
385402
)
386-
if not (bool(args) ^ bool(kwargs)):
403+
if args and kwargs:
387404
raise TypeError(
388-
'set_metadata takes either a single dict argument or keyword arguments'
405+
'Cannot mix positional and keyword arguments in set_metadata'
389406
)
390407
if len(args) == 1:
391-
self._var_metadata = args[0]
408+
self._var_metadata = dict(args[0])
409+
elif len(args) == 2:
410+
name, value = args
411+
self._var_metadata[name] = value
392412
elif kwargs:
393413
self._var_metadata.update(kwargs)
394414
else:
395415
raise TypeError(
396-
f'set_metadata takes either 1 argument or 1 or more keyword arguments, got args={args}, kwargs={kwargs}'
416+
f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, '
417+
f'got args={args}, kwargs={kwargs}'
397418
)
398419

399420
def copy_from(self, other: Variable[A]) -> None:

tests/nnx/variable_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ def test_get_set_metadata(self):
128128
self.assertEqual(v.get_metadata(), {'b': 3, 'c': 4})
129129
self.assertEqual(v.get_metadata('b'), 3)
130130
self.assertEqual(v.get_metadata('c'), 4)
131+
c = v.get_metadata('c')
132+
self.assertEqual(c, 4)
133+
x = v.get_metadata('x', default=10)
134+
self.assertEqual(x, 10)
131135

132136

133137
if __name__ == '__main__':

0 commit comments

Comments
 (0)