@@ -340,16 +340,18 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
340340def _graph_node_meta_call (cls : tp .Type [P ], * args , ** kwargs ) -> P :
341341 node = cls .__new__ (cls , * args , ** kwargs )
342342 vars_obj = vars (node )
343- vars_obj [ '_pytree__state' ] = PytreeState ()
344- vars_obj [ '_pytree__nodes' ] = cls ._pytree__nodes
343+ object . __setattr__ ( node , '_pytree__state' , PytreeState () )
344+ object . __setattr__ ( node , '_pytree__nodes' , cls ._pytree__nodes )
345345 cls ._pytree_meta_construct (node , * args , ** kwargs )
346346 if cls ._pytree__is_pytree :
347347 missing : dict [str , bool ] = {}
348348 for name , value in vars (node ).items ():
349- if name not in vars_obj [ ' _pytree__nodes' ] :
349+ if name not in node . _pytree__nodes :
350350 missing [name ] = is_data (value )
351351 if missing :
352- vars_obj ['_pytree__nodes' ] = vars_obj ['_pytree__nodes' ].update (missing )
352+ object .__setattr__ (
353+ node , '_pytree__nodes' , node ._pytree__nodes .update (missing )
354+ )
353355 check_pytree (node )
354356
355357 return node
@@ -500,11 +502,10 @@ def _setattr(self, name, value: tp.Any) -> None:
500502 if name not in self ._pytree__nodes or (
501503 explicit and self ._pytree__nodes [name ] != data
502504 ):
503- vars (self )['_pytree__nodes' ] = self ._pytree__nodes .update ({name : data })
504- if isinstance (name , str ):
505- object .__setattr__ (self , name , value )
506- else :
507- vars (self )[name ] = value
505+ object .__setattr__ (
506+ self , '_pytree__nodes' , self ._pytree__nodes .update ({name : data })
507+ )
508+ object .__setattr__ (self , name , value )
508509
509510 def _check_value (self , key , value , new_status : AttributeStatus | None ):
510511 def _has_arrays (leaves ):
@@ -739,20 +740,26 @@ def __getstate__(self):
739740 return vars (self ).copy ()
740741
741742 def __setstate__ (self , state ):
742- vars (self ).update (state )
743+ for key , value in state .items ():
744+ object .__setattr__ (self , key , value )
743745
744746 # -------------------------
745747 # Pytree Definition
746748 # -------------------------
747- _pytree__key_sort_fn : tp . Callable | None = None
749+ _pytree__has_int_keys : bool = False
748750
749751 def _pytree__flatten_with_paths (self ):
750- obj_vars = vars (self )
752+ obj_items = vars (self ).items ()
753+ if self ._pytree__has_int_keys :
754+ obj_items = ((_maybe_int (name ), value ) for name , value in obj_items )
755+ key_fn = graph ._type_aware_sort
756+ else :
757+ key_fn = None
751758 node_attributes = self ._pytree__nodes
752759 node_names : list [str ] = []
753760 node_attrs : list [tuple [tp .Any , tp .Any ]] = []
754761 static_attrs : list [tuple [str , tp .Any ]] = []
755- for name , value in sorted (obj_vars . items () , key = self . _pytree__key_sort_fn ):
762+ for name , value in sorted (obj_items , key = key_fn ):
756763 if name in node_attributes and node_attributes [name ]:
757764 node_names .append (name )
758765 node_attrs .append ((
@@ -767,12 +774,17 @@ def _pytree__flatten_with_paths(self):
767774 return node_attrs , (tuple (node_names ), tuple (static_attrs ))
768775
769776 def _pytree__flatten (self ):
770- obj_vars = vars (self )
777+ obj_items = vars (self ).items ()
778+ if self ._pytree__has_int_keys :
779+ obj_items = ((_maybe_int (name ), value ) for name , value in obj_items )
780+ key_fn = graph ._type_aware_sort
781+ else :
782+ key_fn = None
771783 node_attributes = self ._pytree__nodes
772784 node_names : list [str ] = []
773785 node_attrs : list [tp .Any ] = []
774786 static_attrs : list [tuple [str , tp .Any ]] = []
775- for name , value in sorted (obj_vars . items () , key = self . _pytree__key_sort_fn ):
787+ for name , value in sorted (obj_items , key = key_fn ):
776788 if name in node_attributes and node_attributes [name ]:
777789 node_names .append (name )
778790 node_attrs .append (value )
@@ -790,45 +802,58 @@ def _pytree__unflatten(
790802 node_names , static_attrs = static
791803 obj = object .__new__ (cls )
792804 vars_obj = vars (obj )
793- vars_obj .update (zip (node_names , node_attrs , strict = True ))
794- vars_obj .update (static_attrs )
805+ if cls ._pytree__has_int_keys :
806+ node_names = [
807+ str (name ) if isinstance (name , int ) else name for name in node_names
808+ ]
809+ for name , value in zip (node_names , node_attrs , strict = True ):
810+ object .__setattr__ (obj , name , value )
811+ for name , value in static_attrs :
812+ object .__setattr__ (obj , name , value )
795813 return obj
796814
797815 # -------------------------
798816 # Graph Definition
799817 # -------------------------
800818 def _graph_node_flatten (self ):
801- nodes = vars (self )
802- nodes = sorted (nodes .items (), key = self ._pytree__key_sort_fn )
819+ obj_items = vars (self ).items ()
820+ if self ._pytree__has_int_keys :
821+ obj_items = ((_maybe_int (name ), value ) for name , value in obj_items )
822+ key_fn = graph ._type_aware_sort
823+ else :
824+ key_fn = None
825+ nodes = sorted (obj_items , key = key_fn )
803826 return nodes , type (self )
804827
805- def _graph_node_set_key (self , key : str , value : tp .Any ):
806- if not isinstance (key , str ):
807- raise KeyError (f'Invalid key: { key !r} ' )
808- elif (
809- hasattr (self , key )
810- and isinstance (variable := getattr (self , key ), Variable )
811- and isinstance (value , Variable )
812- ):
813- variable .update_from_state (value )
814- else :
815- setattr (self , key , value )
828+ def _graph_node_set_key (self , key , value : tp .Any ):
829+ if self ._pytree__has_int_keys and isinstance (key , int ):
830+ key = str (key )
831+ setattr (self , key , value )
816832
817- def _graph_node_pop_key (self , key : str ):
818- if not isinstance (key , str ):
819- raise KeyError (f'Invalid key: { key !r} ' )
820- return vars (self ).pop (key )
833+ def _graph_node_pop_key (self , key ):
834+ if self ._pytree__has_int_keys and isinstance (key , int ):
835+ key = str (key )
836+ value = getattr (self , key )
837+ delattr (self , key )
838+ return value
821839
822840 @staticmethod
823841 def _graph_node_create_empty (node_type : tp .Type [P ]) -> P :
824842 node = object .__new__ (node_type )
825843 return node
826844
827845 def _graph_node_clear (self ):
828- vars (self ).clear ()
846+ for name in list (vars (self )):
847+ delattr (self , name )
829848
830849 def _graph_node_init (self , attributes : tp .Iterable [tuple [str , tp .Any ]]):
831- vars (self ).update (attributes )
850+ if self ._pytree__has_int_keys :
851+ attributes = (
852+ (str (name ) if isinstance (name , int ) else name , value )
853+ for name , value in attributes
854+ )
855+ for name , value in attributes :
856+ object .__setattr__ (self , name , value )
832857
833858 if tp .TYPE_CHECKING :
834859 def __call__ (self , * args : tp .Any , ** kwargs : tp .Any ) -> tp .Any : ...
@@ -845,3 +870,9 @@ def __init_subclass__(cls, **kwargs):
845870 f'{ pytree !r} for type { cls } .'
846871 )
847872 super ().__init_subclass__ (pytree = pytree , ** kwargs )
873+
874+ def _maybe_int (x ):
875+ try :
876+ return int (x )
877+ except (ValueError , TypeError ):
878+ return x
0 commit comments