@@ -452,25 +452,49 @@ def __init__(self, cls: "Message"):
452
452
self .oneof_group_by_field = by_field
453
453
self .oneof_field_by_group = by_group
454
454
455
+ def init_default_gen (self ):
456
+ default_gen = {}
457
+
458
+ for field in dataclasses .fields (self .cls ):
459
+ meta = FieldMetadata .get (field )
460
+ default_gen [field .name ] = self .cls ._get_field_default_gen (field , meta )
461
+
462
+ self .default_gen = default_gen
463
+
464
+ def init_cls_by_field (self ):
465
+ field_cls = {}
466
+
467
+ for field in dataclasses .fields (self .cls ):
468
+ meta = FieldMetadata .get (field )
469
+ if meta .proto_type == TYPE_MAP :
470
+ assert meta .map_types
471
+ kt = self .cls ._cls_for (field , index = 0 )
472
+ vt = self .cls ._cls_for (field , index = 1 )
473
+ Entry = dataclasses .make_dataclass (
474
+ "Entry" ,
475
+ [
476
+ ("key" , kt , dataclass_field (1 , meta .map_types [0 ])),
477
+ ("value" , vt , dataclass_field (2 , meta .map_types [1 ])),
478
+ ],
479
+ bases = (Message ,),
480
+ )
481
+ make_protoclass (Entry )
482
+ field_cls [field .name ] = Entry
483
+ field_cls [field .name + ".value" ] = vt
484
+ else :
485
+ field_cls [field .name ] = self .cls ._cls_for (field )
486
+
487
+ self .cls_by_field = field_cls
488
+
455
489
def __getattr__ (self , item ):
456
490
# Lazy init because forward reference classes may not be available at the beginning.
457
491
if item == 'default_gen' :
458
- defaults = {}
459
- for field in dataclasses .fields (self .cls ):
460
- meta = FieldMetadata .get (field )
461
- defaults [field .name ] = self .cls ._get_field_default_gen (field , meta )
462
-
463
- self .default_gen = defaults # __getattr__ won't be called next time
464
- return defaults
492
+ self .init_default_gen ()
493
+ return self .default_gen
465
494
466
495
if item == 'cls_by_field' :
467
- field_cls = {}
468
- for field in dataclasses .fields (self .cls ):
469
- meta = FieldMetadata .get (field )
470
- field_cls [field .name ] = self .cls ._type_hint (field .name )
471
-
472
- self .cls_by_field = field_cls # __getattr__ won't be called next time
473
- return field_cls
496
+ self .init_cls_by_field ()
497
+ return self .cls_by_field
474
498
475
499
476
500
def make_protoclass (cls ):
@@ -619,12 +643,13 @@ def _type_hint(cls, field_name: str) -> Type:
619
643
type_hints = get_type_hints (cls , vars (module ))
620
644
return type_hints [field_name ]
621
645
622
- def _cls_for (self , field : dataclasses .Field , index : int = 0 ) -> Type :
646
+ @classmethod
647
+ def _cls_for (cls , field : dataclasses .Field , index : int = 0 ) -> Type :
623
648
"""Get the message class for a field from the type hints."""
624
- cls = self . _betterproto . cls_by_field [ field .name ]
625
- if hasattr (cls , "__args__" ) and index >= 0 :
626
- cls = cls .__args__ [index ]
627
- return cls
649
+ field_cls = cls . _type_hint ( field .name )
650
+ if hasattr (field_cls , "__args__" ) and index >= 0 :
651
+ field_cls = field_cls .__args__ [index ]
652
+ return field_cls
628
653
629
654
def _get_field_default (self , field : dataclasses .Field , meta : FieldMetadata ) -> Any :
630
655
return self ._betterproto .default_gen [field .name ]()
@@ -680,7 +705,7 @@ def _postprocess_single(
680
705
if meta .proto_type == TYPE_STRING :
681
706
value = value .decode ("utf-8" )
682
707
elif meta .proto_type == TYPE_MESSAGE :
683
- cls = self ._cls_for ( field )
708
+ cls = self ._betterproto . cls_by_field [ field . name ]
684
709
685
710
if cls == datetime :
686
711
value = _Timestamp ().parse (value ).to_datetime ()
@@ -694,21 +719,7 @@ def _postprocess_single(
694
719
value = cls ().parse (value )
695
720
value ._serialized_on_wire = True
696
721
elif meta .proto_type == TYPE_MAP :
697
- # TODO: This is slow, use a cache to make it faster since each
698
- # key/value pair will recreate the class.
699
- assert meta .map_types
700
- kt = self ._cls_for (field , index = 0 )
701
- vt = self ._cls_for (field , index = 1 )
702
- Entry = dataclasses .make_dataclass (
703
- "Entry" ,
704
- [
705
- ("key" , kt , dataclass_field (1 , meta .map_types [0 ])),
706
- ("value" , vt , dataclass_field (2 , meta .map_types [1 ])),
707
- ],
708
- bases = (Message ,),
709
- )
710
- make_protoclass (Entry )
711
- value = Entry ().parse (value )
722
+ value = self ._betterproto .cls_by_field [field .name ]().parse (value )
712
723
713
724
return value
714
725
@@ -823,7 +834,7 @@ def to_dict(
823
834
else :
824
835
output [cased_name ] = b64encode (v ).decode ("utf8" )
825
836
elif meta .proto_type == TYPE_ENUM :
826
- enum_values = list (self ._cls_for ( field ) ) # type: ignore
837
+ enum_values = list (self ._betterproto . cls_by_field [ field . name ] ) # type: ignore
827
838
if isinstance (v , list ):
828
839
output [cased_name ] = [enum_values [e ].name for e in v ]
829
840
else :
@@ -849,7 +860,7 @@ def from_dict(self: T, value: dict) -> T:
849
860
if meta .proto_type == "message" :
850
861
v = getattr (self , field .name )
851
862
if isinstance (v , list ):
852
- cls = self ._cls_for ( field )
863
+ cls = self ._betterproto . cls_by_field [ field . name ]
853
864
for i in range (len (value [key ])):
854
865
v .append (cls ().from_dict (value [key ][i ]))
855
866
elif isinstance (v , datetime ):
@@ -866,7 +877,7 @@ def from_dict(self: T, value: dict) -> T:
866
877
v .from_dict (value [key ])
867
878
elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
868
879
v = getattr (self , field .name )
869
- cls = self ._cls_for ( field , index = 1 )
880
+ cls = self ._betterproto . cls_by_field [ field . name + ".value" ]
870
881
for k in value [key ]:
871
882
v [k ] = cls ().from_dict (value [key ][k ])
872
883
else :
@@ -882,7 +893,7 @@ def from_dict(self: T, value: dict) -> T:
882
893
else :
883
894
v = b64decode (value [key ])
884
895
elif meta .proto_type == TYPE_ENUM :
885
- enum_cls = self ._cls_for ( field )
896
+ enum_cls = self ._betterproto . cls_by_field [ field . name ]
886
897
if isinstance (v , list ):
887
898
v = [enum_cls .from_string (e ) for e in v ]
888
899
elif isinstance (v , str ):
0 commit comments