Skip to content

Commit

Permalink
Fix serialization of repeated fields with empty messages
Browse files Browse the repository at this point in the history
Extend test config and utils to support exclusion of certain json samples from
testing for symetry.
  • Loading branch information
nat-n committed Apr 5, 2021
1 parent 7c5ee47 commit 188b186
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 25 deletions.
14 changes: 12 additions & 2 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def encode_varint(value: int) -> bytes:

def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""

if proto_type in [
TYPE_ENUM,
TYPE_BOOL,
Expand Down Expand Up @@ -738,9 +739,18 @@ def __bytes__(self) -> bytes:
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
output += (
_serialize_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or b"\n\x00"
)

elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types
Expand Down
7 changes: 7 additions & 0 deletions tests/inputs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,10 @@
"example_service",
"empty_service",
}


# Indicate json sample messages to skip when testing that json (de)serialization
# is symmetrical becuase some cases legitimately are not symmetrical.
# Each key references the name of the test scenario and the values in the tuple
# Are the names of the json files.
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}
3 changes: 3 additions & 0 deletions tests/inputs/empty_repeated/empty_repeated.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"msg": [{"values":[]}]
}
9 changes: 9 additions & 0 deletions tests/inputs/empty_repeated/empty_repeated.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

message MessageA {
repeated float values = 1;
}

message Test {
repeated MessageA msg = 1;
}
4 changes: 2 additions & 2 deletions tests/inputs/oneof/test_oneof.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

def test_which_count():
message = Test()
message.from_json(get_test_case_json_data("oneof")[0])
message.from_json(get_test_case_json_data("oneof")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)


def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
6 changes: 3 additions & 3 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value():
"""
message = Test()
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)

assert message.move == Move(
Expand All @@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
"""
message = Test()
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
Expand All @@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():

def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0])
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
28 changes: 18 additions & 10 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from collections import namedtuple
from types import ModuleType
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple

import pytest

Expand All @@ -29,7 +29,12 @@


class TestCases:
def __init__(self, path, services: Set[str], xfail: Set[str]):
def __init__(
self,
path,
services: Set[str],
xfail: Set[str],
):
_all = set(get_directories(path)) - {"__pycache__"}
_services = services
_messages = (_all - services) - {"__pycache__"}
Expand Down Expand Up @@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data

for _ in range(repeat):
for json_sample in json_data:
for sample in json_data:
if sample.belongs_to(test_input_config.non_symmetrical_json):
continue

message: betterproto.Message = plugin_module.Test()

message.from_json(json_sample)
message.from_json(sample.json)
message_json = message.to_json(0)

assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(json_sample)
)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(sample.json)
)


@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
Expand All @@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data

for json_sample in json_data:
reference_instance = Parse(json_sample, reference_module().Test())
for sample in json_data:
reference_instance = Parse(sample.json, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()

for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = (
plugin_module.Test().from_json(json_sample)
plugin_module.Test().from_json(sample.json)
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output
Expand Down
33 changes: 25 additions & 8 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
from dataclasses import dataclass
import importlib
import os
import pathlib
import sys
from pathlib import Path
import sys
from types import ModuleType
from typing import Callable, Generator, List, Optional, Union
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

Expand Down Expand Up @@ -47,11 +47,24 @@ async def protoc(
return stdout, stderr, proc.returncode


def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
@dataclass
class TestCaseJsonFile:
json: str
test_name: str
file_name: str

def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
return self.file_name in non_symmetrical_json.get(self.test_name, tuple())


def get_test_case_json_data(
test_case_name: str, *json_file_names: str
) -> List[TestCaseJsonFile]:
"""
:return:
A list of all files found in "inputs_path/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
A list of all files found in "{inputs_path}/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
json_file_names
"""
test_case_dir = inputs_path.joinpath(test_case_name)
possible_file_paths = [
Expand All @@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
if not test_data_file_path.exists():
continue
with test_data_file_path.open("r") as fh:
result.append(fh.read())
result.append(
TestCaseJsonFile(
fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
)
)

return result

Expand All @@ -86,7 +103,7 @@ def find_module(
if predicate(module):
return module

module_path = pathlib.Path(*module.__path__)
module_path = Path(*module.__path__)

for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
if sub == module_path:
Expand Down

0 comments on commit 188b186

Please sign in to comment.