|
29 | 29 | from flax import errors |
30 | 30 | from flax.core import spmd as core_spmd |
31 | 31 | from flax.nnx import filterlib, reprlib, tracers, visualization |
32 | | -from flax.typing import Missing, PathParts, SizeBytes |
| 32 | +from flax.typing import MISSING, Missing, PathParts, SizeBytes |
33 | 33 | import jax.tree_util as jtu |
34 | 34 | import jax.numpy as jnp |
35 | 35 | from jax._src.state.types import AbstractRef |
@@ -172,7 +172,14 @@ class VariableMetadata(tp.Generic[A]): |
172 | 172 | metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) |
173 | 173 |
|
174 | 174 |
|
175 | | -class Variable(tp.Generic[A], reprlib.Representable): |
| 175 | +class VariableMeta(type): |
| 176 | + def __new__(cls, cls_name, bases, attrs): |
| 177 | + if '__slots__' not in attrs: |
| 178 | + attrs['__slots__'] = () |
| 179 | + return super().__new__(cls, cls_name, bases, attrs) |
| 180 | + |
| 181 | + |
| 182 | +class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): |
176 | 183 | """The base class for all ``Variable`` types. Create custom ``Variable`` |
177 | 184 | types by subclassing this class. Numerous NNX graph functions can filter |
178 | 185 | for specific ``Variable`` types, for example, :func:`split`, :func:`state`, |
@@ -353,47 +360,61 @@ def has_ref(self) -> bool: |
353 | 360 | @tp.overload |
354 | 361 | def get_metadata(self) -> dict[str, tp.Any]: ... |
355 | 362 | @tp.overload |
356 | | - def get_metadata(self, name: str) -> tp.Any: ... |
357 | | - def get_metadata(self, name: str | None = None): |
| 363 | + def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ... |
| 364 | + def get_metadata( |
| 365 | + self, name: str | None = None, default: tp.Any = MISSING |
| 366 | + ) -> tp.Any: |
358 | 367 | """Get metadata for the Variable. |
359 | 368 |
|
360 | 369 | Args: |
361 | 370 | name: The key of the metadata element to get. If not provided, returns |
362 | 371 | the full metadata dictionary. |
| 372 | + default: The default value to return if the metadata key is not found. If |
| 373 | + not provided and the key is not found, raises a KeyError. |
363 | 374 | """ |
364 | 375 | if name is None: |
365 | 376 | return self._var_metadata |
| 377 | + if name not in self._var_metadata and not isinstance(default, Missing): |
| 378 | + return default |
366 | 379 | return self._var_metadata[name] |
367 | 380 |
|
368 | 381 | @tp.overload |
369 | 382 | def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... |
370 | 383 | @tp.overload |
| 384 | + def set_metadata(self, name: str, value: tp.Any, /) -> None: ... |
| 385 | + @tp.overload |
371 | 386 | def set_metadata(self, **metadata: tp.Any) -> None: ... |
372 | 387 | def set_metadata(self, *args, **kwargs) -> None: |
373 | 388 | """Set metadata for the Variable. |
374 | 389 |
|
375 | | - `set_metadata` can be called in two ways: |
| 390 | + `set_metadata` can be called in 3 ways: |
376 | 391 |
|
377 | 392 | 1. By passing a dictionary of metadata as the first argument, this will replace |
378 | 393 | the entire Variable's metadata. |
379 | | - 2. By using keyword arguments, these will be merged into the existing Variable's |
380 | | - metadata. |
| 394 | + 2. By passing a name and value as the first two arguments, this will set |
| 395 | + the metadata entry for the given name to the given value. |
| 396 | + 3. By using keyword arguments, this will update the Variable's metadata |
| 397 | + with the provided key-value pairs. |
381 | 398 | """ |
382 | 399 | if not self._trace_state.is_valid(): |
383 | 400 | raise errors.TraceContextError( |
384 | 401 | f'Cannot mutate {type(self).__name__} from a different trace level' |
385 | 402 | ) |
386 | | - if not (bool(args) ^ bool(kwargs)): |
| 403 | + if args and kwargs: |
387 | 404 | raise TypeError( |
388 | | - 'set_metadata takes either a single dict argument or keyword arguments' |
| 405 | + 'Cannot mix positional and keyword arguments in set_metadata' |
389 | 406 | ) |
390 | 407 | if len(args) == 1: |
391 | | - self._var_metadata = args[0] |
| 408 | + self._var_metadata = dict(args[0]) |
| 409 | + elif len(args) == 2: |
| 410 | + name, value = args |
| 411 | + self._var_metadata[name] = value |
392 | 412 | elif kwargs: |
393 | 413 | self._var_metadata.update(kwargs) |
394 | 414 | else: |
395 | 415 | raise TypeError( |
396 | | - f'set_metadata takes either 1 argument or 1 or more keyword arguments, got args={args}, kwargs={kwargs}' |
| 416 | + f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, ' |
| 417 | + f'got args={args}, kwargs={kwargs}' |
397 | 418 | ) |
398 | 419 |
|
399 | 420 | def copy_from(self, other: Variable[A]) -> None: |
|
0 commit comments