Skip to content

Commit 917de09

Browse files
committed
Replace extra decorator with property and lazy initialization so that it is backward compatible.
1 parent 1f7f390 commit 917de09

File tree

3 files changed

+46
-50
lines changed

3 files changed

+46
-50
lines changed

betterproto/__init__.py

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,9 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
433433

434434

435435
class ProtoClassMetadata:
436-
cls: "Message"
436+
cls: Type["Message"]
437437

438-
def __init__(self, cls: "Message"):
438+
def __init__(self, cls: Type["Message"]):
439439
self.cls = cls
440440
by_field = {}
441441
by_group = {}
@@ -452,6 +452,9 @@ def __init__(self, cls: "Message"):
452452
self.oneof_group_by_field = by_field
453453
self.oneof_field_by_group = by_group
454454

455+
self.init_default_gen()
456+
self.init_cls_by_field()
457+
455458
def init_default_gen(self):
456459
default_gen = {}
457460

@@ -478,34 +481,13 @@ def init_cls_by_field(self):
478481
],
479482
bases=(Message,),
480483
)
481-
make_protoclass(Entry)
482484
field_cls[field.name] = Entry
483485
field_cls[field.name + ".value"] = vt
484486
else:
485487
field_cls[field.name] = self.cls._cls_for(field)
486488

487489
self.cls_by_field = field_cls
488490

489-
def __getattr__(self, item):
490-
# Lazy init because forward reference classes may not be available at the beginning.
491-
if item == 'default_gen':
492-
self.init_default_gen()
493-
return self.default_gen
494-
495-
if item == 'cls_by_field':
496-
self.init_cls_by_field()
497-
return self.cls_by_field
498-
499-
500-
def make_protoclass(cls):
501-
setattr(cls, "_betterproto", ProtoClassMetadata(cls))
502-
503-
504-
def protoclass(*args, **kwargs):
505-
cls = dataclasses.dataclass(*args, **kwargs)
506-
make_protoclass(cls)
507-
return cls
508-
509491

510492
class Message(ABC):
511493
"""
@@ -567,6 +549,19 @@ def __setattr__(self, attr: str, value: Any) -> None:
567549

568550
super().__setattr__(attr, value)
569551

552+
@property
553+
def _betterproto(self):
554+
"""
555+
Lazy initialize metadata for each protobuf class.
556+
It may be initialized multiple times in a multi-threaded environment,
557+
but that won't affect the correctness.
558+
"""
559+
meta = getattr(self.__class__, "_betterproto_meta", None)
560+
if not meta:
561+
meta = ProtoClassMetadata(self.__class__)
562+
self.__class__._betterproto_meta = meta
563+
return meta
564+
570565
def __bytes__(self) -> bytes:
571566
"""
572567
Get the binary encoded Protobuf representation of this instance.
@@ -932,7 +927,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
932927
return (field.name, getattr(message, field.name))
933928

934929

935-
@protoclass
930+
@dataclasses.dataclass
936931
class _Duration(Message):
937932
# Signed seconds of the span of time. Must be from -315,576,000,000 to
938933
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@@ -957,7 +952,7 @@ def delta_to_json(delta: timedelta) -> str:
957952
return ".".join(parts) + "s"
958953

959954

960-
@protoclass
955+
@dataclasses.dataclass
961956
class _Timestamp(Message):
962957
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
963958
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@@ -1007,47 +1002,47 @@ def from_dict(self: T, value: Any) -> T:
10071002
return self
10081003

10091004

1010-
@protoclass
1005+
@dataclasses.dataclass
10111006
class _BoolValue(_WrappedMessage):
10121007
value: bool = bool_field(1)
10131008

10141009

1015-
@protoclass
1010+
@dataclasses.dataclass
10161011
class _Int32Value(_WrappedMessage):
10171012
value: int = int32_field(1)
10181013

10191014

1020-
@protoclass
1015+
@dataclasses.dataclass
10211016
class _UInt32Value(_WrappedMessage):
10221017
value: int = uint32_field(1)
10231018

10241019

1025-
@protoclass
1020+
@dataclasses.dataclass
10261021
class _Int64Value(_WrappedMessage):
10271022
value: int = int64_field(1)
10281023

10291024

1030-
@protoclass
1025+
@dataclasses.dataclass
10311026
class _UInt64Value(_WrappedMessage):
10321027
value: int = uint64_field(1)
10331028

10341029

1035-
@protoclass
1030+
@dataclasses.dataclass
10361031
class _FloatValue(_WrappedMessage):
10371032
value: float = float_field(1)
10381033

10391034

1040-
@protoclass
1035+
@dataclasses.dataclass
10411036
class _DoubleValue(_WrappedMessage):
10421037
value: float = double_field(1)
10431038

10441039

1045-
@protoclass
1040+
@dataclasses.dataclass
10461041
class _StringValue(_WrappedMessage):
10471042
value: str = string_field(1)
10481043

10491044

1050-
@protoclass
1045+
@dataclasses.dataclass
10511046
class _BytesValue(_WrappedMessage):
10521047
value: bytes = bytes_field(1)
10531048

betterproto/templates/template.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/test_features.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55

66
def test_has_field():
7-
@betterproto.protoclass
7+
@dataclass
88
class Bar(betterproto.Message):
99
baz: int = betterproto.int32_field(1)
1010

11-
@betterproto.protoclass
11+
@dataclass
1212
class Foo(betterproto.Message):
1313
bar: Bar = betterproto.message_field(1)
1414

@@ -34,11 +34,11 @@ class Foo(betterproto.Message):
3434

3535

3636
def test_class_init():
37-
@betterproto.protoclass
37+
@dataclass
3838
class Bar(betterproto.Message):
3939
name: str = betterproto.string_field(1)
4040

41-
@betterproto.protoclass
41+
@dataclass
4242
class Foo(betterproto.Message):
4343
name: str = betterproto.string_field(1)
4444
child: Bar = betterproto.message_field(2)
@@ -53,7 +53,7 @@ class TestEnum(betterproto.Enum):
5353
ZERO = 0
5454
ONE = 1
5555

56-
@betterproto.protoclass
56+
@dataclass
5757
class Foo(betterproto.Message):
5858
bar: TestEnum = betterproto.enum_field(1)
5959

@@ -67,13 +67,13 @@ class Foo(betterproto.Message):
6767

6868

6969
def test_unknown_fields():
70-
@betterproto.protoclass
70+
@dataclass
7171
class Newer(betterproto.Message):
7272
foo: bool = betterproto.bool_field(1)
7373
bar: int = betterproto.int32_field(2)
7474
baz: str = betterproto.string_field(3)
7575

76-
@betterproto.protoclass
76+
@dataclass
7777
class Older(betterproto.Message):
7878
foo: bool = betterproto.bool_field(1)
7979

@@ -89,11 +89,11 @@ class Older(betterproto.Message):
8989

9090

9191
def test_oneof_support():
92-
@betterproto.protoclass
92+
@dataclass
9393
class Sub(betterproto.Message):
9494
val: int = betterproto.int32_field(1)
9595

96-
@betterproto.protoclass
96+
@dataclass
9797
class Foo(betterproto.Message):
9898
bar: int = betterproto.int32_field(1, group="group1")
9999
baz: str = betterproto.string_field(2, group="group1")
@@ -134,7 +134,7 @@ class Foo(betterproto.Message):
134134

135135

136136
def test_json_casing():
137-
@betterproto.protoclass
137+
@dataclass
138138
class CasingTest(betterproto.Message):
139139
pascal_case: int = betterproto.int32_field(1)
140140
camel_case: int = betterproto.int32_field(2)
@@ -165,7 +165,7 @@ class CasingTest(betterproto.Message):
165165

166166

167167
def test_optional_flag():
168-
@betterproto.protoclass
168+
@dataclass
169169
class Request(betterproto.Message):
170170
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
171171

@@ -180,7 +180,7 @@ class Request(betterproto.Message):
180180

181181

182182
def test_to_dict_default_values():
183-
@betterproto.protoclass
183+
@dataclass
184184
class TestMessage(betterproto.Message):
185185
some_int: int = betterproto.int32_field(1)
186186
some_double: float = betterproto.double_field(2)
@@ -210,7 +210,7 @@ class TestMessage(betterproto.Message):
210210
}
211211

212212
# Some default and some other values
213-
@betterproto.protoclass
213+
@dataclass
214214
class TestMessage2(betterproto.Message):
215215
some_int: int = betterproto.int32_field(1)
216216
some_double: float = betterproto.double_field(2)
@@ -246,11 +246,11 @@ class TestMessage2(betterproto.Message):
246246
}
247247

248248
# Nested messages
249-
@betterproto.protoclass
249+
@dataclass
250250
class TestChildMessage(betterproto.Message):
251251
some_other_int: int = betterproto.int32_field(1)
252252

253-
@betterproto.protoclass
253+
@dataclass
254254
class TestParentMessage(betterproto.Message):
255255
some_int: int = betterproto.int32_field(1)
256256
some_double: float = betterproto.double_field(2)

0 commit comments

Comments
 (0)