@@ -167,7 +167,14 @@ class VariableMetadata(tp.Generic[A]):
167167 metadata : tp .Mapping [str , tp .Any ] = dataclasses .field (default_factory = dict )
168168
169169
170- class Variable (tp .Generic [A ], reprlib .Representable ):
170+ class VariableMeta (type ):
171+ def __new__ (cls , cls_name , bases , attrs ):
172+ if '__slots__' not in attrs :
173+ attrs ['__slots__' ] = ()
174+ return super ().__new__ (cls , cls_name , bases , attrs )
175+
176+
177+ class Variable (tp .Generic [A ], reprlib .Representable , metaclass = VariableMeta ):
171178 """The base class for all ``Variable`` types. Create custom ``Variable``
172179 types by subclassing this class. Numerous NNX graph functions can filter
173180 for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
@@ -303,15 +310,13 @@ def __setattr__(self, name: str, value: tp.Any):
303310 raise errors .TraceContextError (
304311 f'Cannot mutate { type (self ).__name__ } from a different trace level'
305312 )
306- if (
307- name == 'value'
308- or name == 'raw_value'
309- or name == '_var_metadata'
310- or name == '_trace_state'
311- ):
313+ try :
312314 object .__setattr__ (self , name , value )
313- else :
314- self ._var_metadata [name ] = value
315+ except AttributeError :
316+ raise AttributeError (
317+ f'Cannot set attribute { name } .\n '
318+ f"To set Variable metadata use: `variable.set_metadata('{ name } ', value)`."
319+ )
315320
316321 def __delattr__ (self , name : str ):
317322 if not self ._trace_state .is_valid ():
@@ -363,32 +368,40 @@ def get_metadata(self, name: str | None = None):
363368 @tp .overload
364369 def set_metadata (self , metadata : dict [str , tp .Any ], / ) -> None : ...
365370 @tp .overload
371+ def set_metadata (self , name : str , value : tp .Any , / ) -> None : ...
372+ @tp .overload
366373 def set_metadata (self , ** metadata : tp .Any ) -> None : ...
367374 def set_metadata (self , * args , ** kwargs ) -> None :
368375 """Set metadata for the Variable.
369376
370- `set_metadata` can be called in two ways:
377+ `set_metadata` can be called in 3 ways:
371378
372379 1. By passing a dictionary of metadata as the first argument, this will replace
373380 the entire Variable's metadata.
374- 2. By using keyword arguments, these will be merged into the existing Variable's
375- metadata.
381+ 2. By passing a name and value as the first two arguments, this will set
382+ the metadata entry for the given name to the given value.
383+ 3. By using keyword arguments, this will update the Variable's metadata
384+ with the provided key-value pairs.
376385 """
377386 if not self ._trace_state .is_valid ():
378387 raise errors .TraceContextError (
379388 f'Cannot mutate { type (self ).__name__ } from a different trace level'
380389 )
381- if not ( bool ( args ) ^ bool ( kwargs )) :
390+ if args and kwargs :
382391 raise TypeError (
383- 'set_metadata takes either a single dict argument or keyword arguments'
392+ 'Cannot mix positional and keyword arguments in set_metadata '
384393 )
385394 if len (args ) == 1 :
386- self ._var_metadata = args [0 ]
395+ self ._var_metadata = dict (args [0 ])
396+ elif len (args ) == 2 :
397+ name , value = args
398+ self ._var_metadata [name ] = value
387399 elif kwargs :
388400 self ._var_metadata .update (kwargs )
389401 else :
390402 raise TypeError (
391- f'set_metadata takes either 1 argument or 1 or more keyword arguments, got args={ args } , kwargs={ kwargs } '
403+ f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, '
404+ f'got args={ args } , kwargs={ kwargs } '
392405 )
393406
394407 def copy_from (self , other : Variable [A ]) -> None :
0 commit comments