Skip to content

Commit d484405

Browse files
committed
Replace extra decorator with property and lazy initialization so that it is backward compatible.
1 parent c0b05af commit d484405

File tree

3 files changed

+46
-50
lines changed

3 files changed

+46
-50
lines changed

betterproto/__init__.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,9 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
431431

432432

433433
class ProtoClassMetadata:
434-
cls: "Message"
434+
cls: Type["Message"]
435435

436-
def __init__(self, cls: "Message"):
436+
def __init__(self, cls: Type["Message"]):
437437
self.cls = cls
438438
by_field = {}
439439
by_group = {}
@@ -450,6 +450,9 @@ def __init__(self, cls: "Message"):
450450
self.oneof_group_by_field = by_field
451451
self.oneof_field_by_group = by_group
452452

453+
self.init_default_gen()
454+
self.init_cls_by_field()
455+
453456
def init_default_gen(self):
454457
default_gen = {}
455458

@@ -476,34 +479,13 @@ def init_cls_by_field(self):
476479
],
477480
bases=(Message,),
478481
)
479-
make_protoclass(Entry)
480482
field_cls[field.name] = Entry
481483
field_cls[field.name + ".value"] = vt
482484
else:
483485
field_cls[field.name] = self.cls._cls_for(field)
484486

485487
self.cls_by_field = field_cls
486488

487-
def __getattr__(self, item):
488-
# Lazy init because forward reference classes may not be available at the beginning.
489-
if item == 'default_gen':
490-
self.init_default_gen()
491-
return self.default_gen
492-
493-
if item == 'cls_by_field':
494-
self.init_cls_by_field()
495-
return self.cls_by_field
496-
497-
498-
def make_protoclass(cls):
499-
setattr(cls, "_betterproto", ProtoClassMetadata(cls))
500-
501-
502-
def protoclass(*args, **kwargs):
503-
cls = dataclasses.dataclass(*args, **kwargs)
504-
make_protoclass(cls)
505-
return cls
506-
507489

508490
class Message(ABC):
509491
"""
@@ -565,6 +547,19 @@ def __setattr__(self, attr: str, value: Any) -> None:
565547

566548
super().__setattr__(attr, value)
567549

550+
@property
551+
def _betterproto(self):
552+
"""
553+
Lazy initialize metadata for each protobuf class.
554+
It may be initialized multiple times in a multi-threaded environment,
555+
but that won't affect the correctness.
556+
"""
557+
meta = getattr(self.__class__, "_betterproto_meta", None)
558+
if not meta:
559+
meta = ProtoClassMetadata(self.__class__)
560+
self.__class__._betterproto_meta = meta
561+
return meta
562+
568563
def __bytes__(self) -> bytes:
569564
"""
570565
Get the binary encoded Protobuf representation of this instance.
@@ -930,7 +925,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
930925
return (field.name, getattr(message, field.name))
931926

932927

933-
@protoclass
928+
@dataclasses.dataclass
934929
class _Duration(Message):
935930
# Signed seconds of the span of time. Must be from -315,576,000,000 to
936931
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@@ -955,7 +950,7 @@ def delta_to_json(delta: timedelta) -> str:
955950
return ".".join(parts) + "s"
956951

957952

958-
@protoclass
953+
@dataclasses.dataclass
959954
class _Timestamp(Message):
960955
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
961956
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@@ -1005,47 +1000,47 @@ def from_dict(self: T, value: Any) -> T:
10051000
return self
10061001

10071002

1008-
@protoclass
1003+
@dataclasses.dataclass
10091004
class _BoolValue(_WrappedMessage):
10101005
value: bool = bool_field(1)
10111006

10121007

1013-
@protoclass
1008+
@dataclasses.dataclass
10141009
class _Int32Value(_WrappedMessage):
10151010
value: int = int32_field(1)
10161011

10171012

1018-
@protoclass
1013+
@dataclasses.dataclass
10191014
class _UInt32Value(_WrappedMessage):
10201015
value: int = uint32_field(1)
10211016

10221017

1023-
@protoclass
1018+
@dataclasses.dataclass
10241019
class _Int64Value(_WrappedMessage):
10251020
value: int = int64_field(1)
10261021

10271022

1028-
@protoclass
1023+
@dataclasses.dataclass
10291024
class _UInt64Value(_WrappedMessage):
10301025
value: int = uint64_field(1)
10311026

10321027

1033-
@protoclass
1028+
@dataclasses.dataclass
10341029
class _FloatValue(_WrappedMessage):
10351030
value: float = float_field(1)
10361031

10371032

1038-
@protoclass
1033+
@dataclasses.dataclass
10391034
class _DoubleValue(_WrappedMessage):
10401035
value: float = double_field(1)
10411036

10421037

1043-
@protoclass
1038+
@dataclasses.dataclass
10441039
class _StringValue(_WrappedMessage):
10451040
value: str = string_field(1)
10461041

10471042

1048-
@protoclass
1043+
@dataclasses.dataclass
10491044
class _BytesValue(_WrappedMessage):
10501045
value: bytes = bytes_field(1)
10511046

betterproto/templates/template.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/test_features.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55

66
def test_has_field():
7-
@betterproto.protoclass
7+
@dataclass
88
class Bar(betterproto.Message):
99
baz: int = betterproto.int32_field(1)
1010

11-
@betterproto.protoclass
11+
@dataclass
1212
class Foo(betterproto.Message):
1313
bar: Bar = betterproto.message_field(1)
1414

@@ -34,11 +34,11 @@ class Foo(betterproto.Message):
3434

3535

3636
def test_class_init():
37-
@betterproto.protoclass
37+
@dataclass
3838
class Bar(betterproto.Message):
3939
name: str = betterproto.string_field(1)
4040

41-
@betterproto.protoclass
41+
@dataclass
4242
class Foo(betterproto.Message):
4343
name: str = betterproto.string_field(1)
4444
child: Bar = betterproto.message_field(2)
@@ -53,7 +53,7 @@ class TestEnum(betterproto.Enum):
5353
ZERO = 0
5454
ONE = 1
5555

56-
@betterproto.protoclass
56+
@dataclass
5757
class Foo(betterproto.Message):
5858
bar: TestEnum = betterproto.enum_field(1)
5959

@@ -67,13 +67,13 @@ class Foo(betterproto.Message):
6767

6868

6969
def test_unknown_fields():
70-
@betterproto.protoclass
70+
@dataclass
7171
class Newer(betterproto.Message):
7272
foo: bool = betterproto.bool_field(1)
7373
bar: int = betterproto.int32_field(2)
7474
baz: str = betterproto.string_field(3)
7575

76-
@betterproto.protoclass
76+
@dataclass
7777
class Older(betterproto.Message):
7878
foo: bool = betterproto.bool_field(1)
7979

@@ -89,11 +89,11 @@ class Older(betterproto.Message):
8989

9090

9191
def test_oneof_support():
92-
@betterproto.protoclass
92+
@dataclass
9393
class Sub(betterproto.Message):
9494
val: int = betterproto.int32_field(1)
9595

96-
@betterproto.protoclass
96+
@dataclass
9797
class Foo(betterproto.Message):
9898
bar: int = betterproto.int32_field(1, group="group1")
9999
baz: str = betterproto.string_field(2, group="group1")
@@ -134,7 +134,7 @@ class Foo(betterproto.Message):
134134

135135

136136
def test_json_casing():
137-
@betterproto.protoclass
137+
@dataclass
138138
class CasingTest(betterproto.Message):
139139
pascal_case: int = betterproto.int32_field(1)
140140
camel_case: int = betterproto.int32_field(2)
@@ -165,7 +165,7 @@ class CasingTest(betterproto.Message):
165165

166166

167167
def test_optional_flag():
168-
@betterproto.protoclass
168+
@dataclass
169169
class Request(betterproto.Message):
170170
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
171171

@@ -180,7 +180,7 @@ class Request(betterproto.Message):
180180

181181

182182
def test_to_dict_default_values():
183-
@betterproto.protoclass
183+
@dataclass
184184
class TestMessage(betterproto.Message):
185185
some_int: int = betterproto.int32_field(1)
186186
some_double: float = betterproto.double_field(2)
@@ -210,7 +210,7 @@ class TestMessage(betterproto.Message):
210210
}
211211

212212
# Some default and some other values
213-
@betterproto.protoclass
213+
@dataclass
214214
class TestMessage2(betterproto.Message):
215215
some_int: int = betterproto.int32_field(1)
216216
some_double: float = betterproto.double_field(2)
@@ -246,11 +246,11 @@ class TestMessage2(betterproto.Message):
246246
}
247247

248248
# Nested messages
249-
@betterproto.protoclass
249+
@dataclass
250250
class TestChildMessage(betterproto.Message):
251251
some_other_int: int = betterproto.int32_field(1)
252252

253-
@betterproto.protoclass
253+
@dataclass
254254
class TestParentMessage(betterproto.Message):
255255
some_int: int = betterproto.int32_field(1)
256256
some_double: float = betterproto.double_field(2)

0 commit comments

Comments
 (0)