Skip to content

Fix oneof serialization with proto3 field presence #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class FieldMetadata:
group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
# Whether the field was declared optional (unset value is None rather than PLACEHOLDER)
optional: Optional[bool] = False

@staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata":
Expand All @@ -165,7 +167,9 @@ def dataclass_field(
return dataclasses.field(
default=None if optional else PLACEHOLDER,
metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
"betterproto": FieldMetadata(
number, proto_type, map_types, group, wraps, optional
)
},
)

Expand Down Expand Up @@ -620,7 +624,8 @@ def __post_init__(self) -> None:
if meta.group:
group_current.setdefault(meta.group)

if self.__raw_get(field_name) != PLACEHOLDER:
value = self.__raw_get(field_name)
if value != PLACEHOLDER and not (meta.optional and value is None):
# Found a non-sentinel value
all_sentinel = False

Expand Down Expand Up @@ -1043,7 +1048,6 @@ def to_dict(
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
field_is_optional = defaults[field_name] is type(None)
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
Expand Down Expand Up @@ -1082,7 +1086,8 @@ def to_dict(
if value or include_default_values:
output[cased_name] = value
elif value is None:
output[cased_name] = None
if include_default_values:
output[cased_name] = value
elif (
value._serialized_on_wire
or include_default_values
Expand All @@ -1109,16 +1114,17 @@ def to_dict(
if field_is_repeated:
output[cased_name] = [str(n) for n in value]
elif value is None:
output[cased_name] = value
if include_default_values:
output[cased_name] = value
else:
output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES:
if field_is_repeated:
output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
elif value is None:
output[cased_name] = None
elif value is None and include_default_values:
output[cased_name] = value
else:
output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
Expand All @@ -1132,8 +1138,9 @@ def to_dict(
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
elif value is None:
output[cased_name] = None
elif field_is_optional:
if include_default_values:
output[cased_name] = value
elif meta.optional:
enum_class = field_types[field_name].__args__[0]
output[cased_name] = enum_class(value).name
else:
Expand Down Expand Up @@ -1173,9 +1180,6 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
if value[key] is None and self._get_field_default(key) == None:
# Setting an optional value to None.
setattr(self, field_name, None)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name]
if cls == datetime:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
{
"test1": null,
"test2": null,
"test3": null,
"test4": null,
"test5": null,
"test6": null,
"test7": null,
"test8": null
}
{}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"test2": false,
"test3": "",
"test4": "",
"test5": null,
"test6": "A",
"test7": "0",
"test8": 0
Expand Down
38 changes: 38 additions & 0 deletions tests/inputs/proto3_field_presence/test_proto3_field_presence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json

from tests.output_betterproto.proto3_field_presence import Test, InnerTest, TestEnum


def test_null_fields_json():
    """Check that a JSON ``null`` is equivalent to leaving the field out.

    This applies to fields with explicit presence: both spellings must
    deserialize to equal objects and serialize back to the same JSON.
    """

    def check_normalized(ref_json: str, obj_json: str) -> None:
        """Compare two JSON strings that each describe a `Test` object.

        Both must deserialize to equal objects, and re-serializing the
        object built from `obj_json` must produce `ref_json`, which is
        the normalized form.
        """
        expected = Test().from_json(ref_json)
        actual = Test().from_json(obj_json)

        assert actual == expected
        assert json.loads(actual.to_json(0)) == json.loads(ref_json)

    # (normalized JSON, JSON spelling out explicit nulls)
    cases = [
        ("{}", '{ "test1": null, "test2": null, "test3": null }'),
        ("{}", '{ "test4": null, "test5": null, "test6": null }'),
        ("{}", '{ "test7": null, "test8": null }'),
        ('{ "test5": {} }', '{ "test3": null, "test5": {} }'),
    ]
    for ref_json, obj_json in cases:
        check_normalized(ref_json, obj_json)

    # None values are only exported when include_default_values is set.
    blank = Test()
    assert blank.to_dict() == {}

    field_names = (
        "test1",
        "test2",
        "test3",
        "test4",
        "test5",
        "test6",
        "test7",
        "test8",
    )
    # dict.fromkeys maps every listed name to None.
    assert blank.to_dict(include_default_values=True) == dict.fromkeys(field_names)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"nested": {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
syntax = "proto3";

// Root message of this fixture. Its `kind` oneof holds either a `Nested`
// message (field 1) or a `WithOptional` message (field 2).
message Test {
  oneof kind {
    Nested nested = 1;
    WithOptional with_optional = 2;
  }
}

// Innermost message: a single bool field declared with the proto3
// `optional` keyword.
message InnerNested {
  optional bool a = 1;
}

// Wraps an `InnerNested` message, so the `optional` field sits two levels
// below the `Test.kind` oneof.
message Nested {
  InnerNested inner = 1;
}

// Message whose only field is a proto3 `optional` bool. The field uses
// number 2; no field 1 is declared in this message.
message WithOptional {
  optional bool b = 2;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from tests.output_betterproto.proto3_field_presence_oneof import (
Test,
InnerNested,
Nested,
WithOptional,
)


def test_serialization():
    """Check that unset explicit-presence fields add nothing to the payload.

    The selected oneof member must serialize as a single length-delimited
    field of length 0, whether its inner optional fields are left out or
    set to None.
    """

    def expect_wire(message: Test, hex_payload: str) -> None:
        """Assert that `message` serializes to exactly `hex_payload`."""
        assert bytes(message) == bytearray.fromhex(hex_payload)

    # '0a' => tag 1, length delimited
    # '00' => length: 0
    empty_nested = "0a 00"
    expect_wire(Test(nested=Nested()), empty_nested)
    expect_wire(Test(nested=Nested(inner=None)), empty_nested)
    expect_wire(Test(nested=Nested(inner=InnerNested(a=None))), empty_nested)

    # '12' => tag 2, length delimited
    # '00' => length: 0
    empty_with_optional = "12 00"
    expect_wire(Test(with_optional=WithOptional()), empty_with_optional)
    expect_wire(Test(with_optional=WithOptional(b=None)), empty_with_optional)