Skip to content

Commit ef00e72

Browse files
committed
Replace extra decorator with property and lazy initialization so that it is backward compatible.
1 parent c0b05af commit ef00e72

File tree

3 files changed

+43
-40
lines changed

3 files changed

+43
-40
lines changed

betterproto/__init__.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,9 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
431431

432432

433433
class ProtoClassMetadata:
434-
cls: "Message"
434+
cls: Type["Message"]
435435

436-
def __init__(self, cls: "Message"):
436+
def __init__(self, cls: Type["Message"]):
437437
self.cls = cls
438438
by_field = {}
439439
by_group = {}
@@ -476,7 +476,6 @@ def init_cls_by_field(self):
476476
],
477477
bases=(Message,),
478478
)
479-
make_protoclass(Entry)
480479
field_cls[field.name] = Entry
481480
field_cls[field.name + ".value"] = vt
482481
else:
@@ -495,16 +494,6 @@ def __getattr__(self, item):
495494
return self.cls_by_field
496495

497496

498-
def make_protoclass(cls):
499-
setattr(cls, "_betterproto", ProtoClassMetadata(cls))
500-
501-
502-
def protoclass(*args, **kwargs):
503-
cls = dataclasses.dataclass(*args, **kwargs)
504-
make_protoclass(cls)
505-
return cls
506-
507-
508497
class Message(ABC):
509498
"""
510499
A protobuf message base class. Generated code will inherit from this and
@@ -565,6 +554,19 @@ def __setattr__(self, attr: str, value: Any) -> None:
565554

566555
super().__setattr__(attr, value)
567556

557+
@property
558+
def _betterproto(self):
559+
"""
560+
Lazy initialize metadata for each protobuf class.
561+
It may be initialized multiple times in a multi-threaded environment,
562+
but that won't affect the correctness.
563+
"""
564+
meta = getattr(self.__class__, "_betterproto_meta", None)
565+
if not meta:
566+
meta = ProtoClassMetadata(self.__class__)
567+
self.__class__._betterproto_meta = meta
568+
return meta
569+
568570
def __bytes__(self) -> bytes:
569571
"""
570572
Get the binary encoded Protobuf representation of this instance.
@@ -930,7 +932,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
930932
return (field.name, getattr(message, field.name))
931933

932934

933-
@protoclass
935+
@dataclasses.dataclass
934936
class _Duration(Message):
935937
# Signed seconds of the span of time. Must be from -315,576,000,000 to
936938
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@@ -955,7 +957,7 @@ def delta_to_json(delta: timedelta) -> str:
955957
return ".".join(parts) + "s"
956958

957959

958-
@protoclass
960+
@dataclasses.dataclass
959961
class _Timestamp(Message):
960962
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
961963
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@@ -1005,47 +1007,47 @@ def from_dict(self: T, value: Any) -> T:
10051007
return self
10061008

10071009

1008-
@protoclass
1010+
@dataclasses.dataclass
10091011
class _BoolValue(_WrappedMessage):
10101012
value: bool = bool_field(1)
10111013

10121014

1013-
@protoclass
1015+
@dataclasses.dataclass
10141016
class _Int32Value(_WrappedMessage):
10151017
value: int = int32_field(1)
10161018

10171019

1018-
@protoclass
1020+
@dataclasses.dataclass
10191021
class _UInt32Value(_WrappedMessage):
10201022
value: int = uint32_field(1)
10211023

10221024

1023-
@protoclass
1025+
@dataclasses.dataclass
10241026
class _Int64Value(_WrappedMessage):
10251027
value: int = int64_field(1)
10261028

10271029

1028-
@protoclass
1030+
@dataclasses.dataclass
10291031
class _UInt64Value(_WrappedMessage):
10301032
value: int = uint64_field(1)
10311033

10321034

1033-
@protoclass
1035+
@dataclasses.dataclass
10341036
class _FloatValue(_WrappedMessage):
10351037
value: float = float_field(1)
10361038

10371039

1038-
@protoclass
1040+
@dataclasses.dataclass
10391041
class _DoubleValue(_WrappedMessage):
10401042
value: float = double_field(1)
10411043

10421044

1043-
@protoclass
1045+
@dataclasses.dataclass
10441046
class _StringValue(_WrappedMessage):
10451047
value: str = string_field(1)
10461048

10471049

1048-
@protoclass
1050+
@dataclasses.dataclass
10491051
class _BytesValue(_WrappedMessage):
10501052
value: bytes = bytes_field(1)
10511053

betterproto/templates/template.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/test_features.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55

66
def test_has_field():
7-
@betterproto.protoclass
7+
@dataclass
88
class Bar(betterproto.Message):
99
baz: int = betterproto.int32_field(1)
1010

11-
@betterproto.protoclass
11+
@dataclass
1212
class Foo(betterproto.Message):
1313
bar: Bar = betterproto.message_field(1)
1414

@@ -34,11 +34,11 @@ class Foo(betterproto.Message):
3434

3535

3636
def test_class_init():
37-
@betterproto.protoclass
37+
@dataclass
3838
class Bar(betterproto.Message):
3939
name: str = betterproto.string_field(1)
4040

41-
@betterproto.protoclass
41+
@dataclass
4242
class Foo(betterproto.Message):
4343
name: str = betterproto.string_field(1)
4444
child: Bar = betterproto.message_field(2)
@@ -53,7 +53,7 @@ class TestEnum(betterproto.Enum):
5353
ZERO = 0
5454
ONE = 1
5555

56-
@betterproto.protoclass
56+
@dataclass
5757
class Foo(betterproto.Message):
5858
bar: TestEnum = betterproto.enum_field(1)
5959

@@ -67,13 +67,13 @@ class Foo(betterproto.Message):
6767

6868

6969
def test_unknown_fields():
70-
@betterproto.protoclass
70+
@dataclass
7171
class Newer(betterproto.Message):
7272
foo: bool = betterproto.bool_field(1)
7373
bar: int = betterproto.int32_field(2)
7474
baz: str = betterproto.string_field(3)
7575

76-
@betterproto.protoclass
76+
@dataclass
7777
class Older(betterproto.Message):
7878
foo: bool = betterproto.bool_field(1)
7979

@@ -89,11 +89,11 @@ class Older(betterproto.Message):
8989

9090

9191
def test_oneof_support():
92-
@betterproto.protoclass
92+
@dataclass
9393
class Sub(betterproto.Message):
9494
val: int = betterproto.int32_field(1)
9595

96-
@betterproto.protoclass
96+
@dataclass
9797
class Foo(betterproto.Message):
9898
bar: int = betterproto.int32_field(1, group="group1")
9999
baz: str = betterproto.string_field(2, group="group1")
@@ -134,7 +134,7 @@ class Foo(betterproto.Message):
134134

135135

136136
def test_json_casing():
137-
@betterproto.protoclass
137+
@dataclass
138138
class CasingTest(betterproto.Message):
139139
pascal_case: int = betterproto.int32_field(1)
140140
camel_case: int = betterproto.int32_field(2)
@@ -165,7 +165,7 @@ class CasingTest(betterproto.Message):
165165

166166

167167
def test_optional_flag():
168-
@betterproto.protoclass
168+
@dataclass
169169
class Request(betterproto.Message):
170170
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
171171

@@ -180,7 +180,7 @@ class Request(betterproto.Message):
180180

181181

182182
def test_to_dict_default_values():
183-
@betterproto.protoclass
183+
@dataclass
184184
class TestMessage(betterproto.Message):
185185
some_int: int = betterproto.int32_field(1)
186186
some_double: float = betterproto.double_field(2)
@@ -210,7 +210,7 @@ class TestMessage(betterproto.Message):
210210
}
211211

212212
# Some default and some other values
213-
@betterproto.protoclass
213+
@dataclass
214214
class TestMessage2(betterproto.Message):
215215
some_int: int = betterproto.int32_field(1)
216216
some_double: float = betterproto.double_field(2)
@@ -246,11 +246,11 @@ class TestMessage2(betterproto.Message):
246246
}
247247

248248
# Nested messages
249-
@betterproto.protoclass
249+
@dataclass
250250
class TestChildMessage(betterproto.Message):
251251
some_other_int: int = betterproto.int32_field(1)
252252

253-
@betterproto.protoclass
253+
@dataclass
254254
class TestParentMessage(betterproto.Message):
255255
some_int: int = betterproto.int32_field(1)
256256
some_double: float = betterproto.double_field(2)

0 commit comments

Comments
 (0)