Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,28 @@ def __repr__(self) -> str:
def __getattribute__(self, name: str) -> Any:
"""
Lazily initialize default values to avoid infinite recursion for recursive
message types
message types.
Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
"""
try:
group_current = super().__getattribute__("_group_current")
except AttributeError:
pass
else:
if name not in {"__class__", "_betterproto"}:
group = self._betterproto.oneof_group_by_field.get(name)
if group is not None and group_current[group] != name:
if sys.version_info < (3, 10):
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}"
)
else:
raise AttributeError(
f"{group!r} is set to {group_current[group]!r}, not {name!r}",
name=name,
obj=self,
)

value = super().__getattribute__(name)
if value is not PLACEHOLDER:
return value
Expand Down Expand Up @@ -761,7 +781,10 @@ def __bytes__(self) -> bytes:
"""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
continue

if value is None:
# Optional items should be skipped. This is used for the Google
Expand All @@ -775,9 +798,7 @@ def __bytes__(self) -> bytes:
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = (
meta.group and self._group_current[meta.group] == field_name
)
selected_in_group = bool(meta.group)

# Empty messages can still be sent on the wire if they were
# set (or received empty).
Expand Down Expand Up @@ -1016,7 +1037,12 @@ def parse(self: T, data: bytes) -> T:
parsed.wire_type, meta, field_name, parsed.value
)

current = getattr(self, field_name)
try:
current = getattr(self, field_name)
except AttributeError:
current = self._get_field_default(field_name)
setattr(self, field_name, current)

if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
Expand Down Expand Up @@ -1077,7 +1103,10 @@ def to_dict(
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
value = self._get_field_default(field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
Expand Down Expand Up @@ -1209,7 +1238,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:

if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
Expand Down Expand Up @@ -1486,7 +1515,6 @@ def _validate_field_groups(cls, values):
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore

for group, field_set in group_to_one_ofs.items():

if len(field_set) == 1:
(field,) = field_set
field_name = field.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():

# None of these fields were explicitly set BUT they should not actually be null
# themselves
assert isinstance(message.foo, Foo)
assert isinstance(message2.foo, Foo)
assert not hasattr(message, "foo")
assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
assert not hasattr(message2, "foo")
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER

assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo)
Expand Down
13 changes: 6 additions & 7 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)

assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)

Expand All @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)

Expand All @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert not hasattr(message, "signal")
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
9 changes: 5 additions & 4 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,18 @@ class Foo(betterproto.Message):
foo.baz = "test"

# Other oneof fields should now be unset
assert foo.bar == 0
assert not hasattr(foo, "bar")
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group1")[0] == "baz"

foo.sub.val = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a behaviour change?

Copy link
Contributor Author

@a-khabarov a-khabarov Jul 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR makes it impossible to use foo.sub.val = 1 when foo.sub is unset in the group.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bit of a tradeoff. With this change the users of betterproto can no longer use the foo.sub.val = 1 syntax for fields that are unset in groups, but this also means that there is less risk of them accidentally changing which field is set in a group.

foo.sub = Sub(val=1)
assert betterproto.serialized_on_wire(foo.sub)

foo.abc = "test"

# Group 1 shouldn't be touched, group 2 should have reset
assert foo.sub.val == 0
assert betterproto.serialized_on_wire(foo.sub) is False
assert not hasattr(foo, "sub")
assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group2")[0] == "abc"

# Zero value should always serialize for one-of
Expand Down