Skip to content

Commit 1f7f390

Browse files
committed
Cache resolved classes for fields, so that there's no new data classes generated while deserializing.
1 parent 3d001a2 commit 1f7f390

File tree

1 file changed

+50
-39
lines changed

1 file changed

+50
-39
lines changed

betterproto/__init__.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -452,25 +452,49 @@ def __init__(self, cls: "Message"):
452452
self.oneof_group_by_field = by_field
453453
self.oneof_field_by_group = by_group
454454

455+
def init_default_gen(self):
456+
default_gen = {}
457+
458+
for field in dataclasses.fields(self.cls):
459+
meta = FieldMetadata.get(field)
460+
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
461+
462+
self.default_gen = default_gen
463+
464+
def init_cls_by_field(self):
465+
field_cls = {}
466+
467+
for field in dataclasses.fields(self.cls):
468+
meta = FieldMetadata.get(field)
469+
if meta.proto_type == TYPE_MAP:
470+
assert meta.map_types
471+
kt = self.cls._cls_for(field, index=0)
472+
vt = self.cls._cls_for(field, index=1)
473+
Entry = dataclasses.make_dataclass(
474+
"Entry",
475+
[
476+
("key", kt, dataclass_field(1, meta.map_types[0])),
477+
("value", vt, dataclass_field(2, meta.map_types[1])),
478+
],
479+
bases=(Message,),
480+
)
481+
make_protoclass(Entry)
482+
field_cls[field.name] = Entry
483+
field_cls[field.name + ".value"] = vt
484+
else:
485+
field_cls[field.name] = self.cls._cls_for(field)
486+
487+
self.cls_by_field = field_cls
488+
455489
def __getattr__(self, item):
456490
# Lazy init because forward reference classes may not be available at the beginning.
457491
if item == 'default_gen':
458-
defaults = {}
459-
for field in dataclasses.fields(self.cls):
460-
meta = FieldMetadata.get(field)
461-
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
462-
463-
self.default_gen = defaults # __getattr__ won't be called next time
464-
return defaults
492+
self.init_default_gen()
493+
return self.default_gen
465494

466495
if item == 'cls_by_field':
467-
field_cls = {}
468-
for field in dataclasses.fields(self.cls):
469-
meta = FieldMetadata.get(field)
470-
field_cls[field.name] = self.cls._type_hint(field.name)
471-
472-
self.cls_by_field = field_cls # __getattr__ won't be called next time
473-
return field_cls
496+
self.init_cls_by_field()
497+
return self.cls_by_field
474498

475499

476500
def make_protoclass(cls):
@@ -619,12 +643,13 @@ def _type_hint(cls, field_name: str) -> Type:
619643
type_hints = get_type_hints(cls, vars(module))
620644
return type_hints[field_name]
621645

622-
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
646+
@classmethod
647+
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
623648
"""Get the message class for a field from the type hints."""
624-
cls = self._betterproto.cls_by_field[field.name]
625-
if hasattr(cls, "__args__") and index >= 0:
626-
cls = cls.__args__[index]
627-
return cls
649+
field_cls = cls._type_hint(field.name)
650+
if hasattr(field_cls, "__args__") and index >= 0:
651+
field_cls = field_cls.__args__[index]
652+
return field_cls
628653

629654
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
630655
return self._betterproto.default_gen[field.name]()
@@ -680,7 +705,7 @@ def _postprocess_single(
680705
if meta.proto_type == TYPE_STRING:
681706
value = value.decode("utf-8")
682707
elif meta.proto_type == TYPE_MESSAGE:
683-
cls = self._cls_for(field)
708+
cls = self._betterproto.cls_by_field[field.name]
684709

685710
if cls == datetime:
686711
value = _Timestamp().parse(value).to_datetime()
@@ -694,21 +719,7 @@ def _postprocess_single(
694719
value = cls().parse(value)
695720
value._serialized_on_wire = True
696721
elif meta.proto_type == TYPE_MAP:
697-
# TODO: This is slow, use a cache to make it faster since each
698-
# key/value pair will recreate the class.
699-
assert meta.map_types
700-
kt = self._cls_for(field, index=0)
701-
vt = self._cls_for(field, index=1)
702-
Entry = dataclasses.make_dataclass(
703-
"Entry",
704-
[
705-
("key", kt, dataclass_field(1, meta.map_types[0])),
706-
("value", vt, dataclass_field(2, meta.map_types[1])),
707-
],
708-
bases=(Message,),
709-
)
710-
make_protoclass(Entry)
711-
value = Entry().parse(value)
722+
value = self._betterproto.cls_by_field[field.name]().parse(value)
712723

713724
return value
714725

@@ -823,7 +834,7 @@ def to_dict(
823834
else:
824835
output[cased_name] = b64encode(v).decode("utf8")
825836
elif meta.proto_type == TYPE_ENUM:
826-
enum_values = list(self._cls_for(field)) # type: ignore
837+
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
827838
if isinstance(v, list):
828839
output[cased_name] = [enum_values[e].name for e in v]
829840
else:
@@ -849,7 +860,7 @@ def from_dict(self: T, value: dict) -> T:
849860
if meta.proto_type == "message":
850861
v = getattr(self, field.name)
851862
if isinstance(v, list):
852-
cls = self._cls_for(field)
863+
cls = self._betterproto.cls_by_field[field.name]
853864
for i in range(len(value[key])):
854865
v.append(cls().from_dict(value[key][i]))
855866
elif isinstance(v, datetime):
@@ -866,7 +877,7 @@ def from_dict(self: T, value: dict) -> T:
866877
v.from_dict(value[key])
867878
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
868879
v = getattr(self, field.name)
869-
cls = self._cls_for(field, index=1)
880+
cls = self._betterproto.cls_by_field[field.name + ".value"]
870881
for k in value[key]:
871882
v[k] = cls().from_dict(value[key][k])
872883
else:
@@ -882,7 +893,7 @@ def from_dict(self: T, value: dict) -> T:
882893
else:
883894
v = b64decode(value[key])
884895
elif meta.proto_type == TYPE_ENUM:
885-
enum_cls = self._cls_for(field)
896+
enum_cls = self._betterproto.cls_by_field[field.name]
886897
if isinstance(v, list):
887898
v = [enum_cls.from_string(e) for e in v]
888899
elif isinstance(v, str):

0 commit comments

Comments
 (0)