diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f6f2d1a95..91d1420e9 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -307,6 +307,7 @@ def encode_varint(value: int) -> bytes: def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: """Adjusts values before serialization.""" + if proto_type in [ TYPE_ENUM, TYPE_BOOL, @@ -738,9 +739,18 @@ def __bytes__(self) -> bytes: output += _serialize_single(meta.number, TYPE_BYTES, buf) else: for item in value: - output += _serialize_single( - meta.number, meta.proto_type, item, wraps=meta.wraps or "" + output += ( + _serialize_single( + meta.number, + meta.proto_type, + item, + wraps=meta.wraps or "", + ) + # if it's an empty message it still needs to be represented + # as an item in the repeated list + or b"\n\x00" ) + elif isinstance(value, dict): for k, v in value.items(): assert meta.map_types diff --git a/tests/inputs/config.py b/tests/inputs/config.py index f95aad207..f82139916 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -21,3 +21,10 @@ "example_service", "empty_service", } + + +# Indicate json sample messages to skip when testing that json (de)serialization +# is symmetrical becuase some cases legitimately are not symmetrical. +# Each key references the name of the test scenario and the values in the tuple +# Are the names of the json files. +non_symmetrical_json = {"empty_repeated": ("empty_repeated",)} diff --git a/tests/inputs/empty_repeated/empty_repeated.json b/tests/inputs/empty_repeated/empty_repeated.json new file mode 100644 index 000000000..12a801c6f --- /dev/null +++ b/tests/inputs/empty_repeated/empty_repeated.json @@ -0,0 +1,3 @@ +{ + "msg": [{"values":[]}] +} diff --git a/tests/inputs/empty_repeated/empty_repeated.proto b/tests/inputs/empty_repeated/empty_repeated.proto new file mode 100644 index 000000000..3be831ace --- /dev/null +++ b/tests/inputs/empty_repeated/empty_repeated.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +message MessageA { + repeated float values = 1; +} + +message Test { + repeated MessageA msg = 1; +} diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py index ac9af9eb8..d1267659f 100644 --- a/tests/inputs/oneof/test_oneof.py +++ b/tests/inputs/oneof/test_oneof.py @@ -5,11 +5,11 @@ def test_which_count(): message = Test() - message.from_json(get_test_case_json_data("oneof")[0]) + message.from_json(get_test_case_json_data("oneof")[0].json) assert betterproto.which_one_of(message, "foo") == ("pitied", 100) def test_which_name(): message = Test() - message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0]) + message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 73b37c6e3..7e287d4a4 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value(): """ message = Test() message.from_json( - get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0] + get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json ) assert message.move == Move( @@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): """ message = Test() message.from_json( - get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0] + get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json ) assert message.move == Move( x=0, y=0 @@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): def test_which_one_of_returns_second_field_when_set(): message = Test() - message.from_json(get_test_case_json_data("oneof_enum")[0]) + message.from_json(get_test_case_json_data("oneof_enum")[0].json) assert message.move == Move(x=2, y=3) assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 6d6907c61..dbcf1975a 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -5,7 +5,7 @@ import sys from collections import namedtuple from types import ModuleType -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Set, Tuple import pytest @@ -29,7 +29,12 @@ class TestCases: - def __init__(self, path, services: Set[str], xfail: Set[str]): + def __init__( + self, + path, + services: Set[str], + xfail: Set[str], + ): _all = set(get_directories(path)) - {"__pycache__"} _services = services _messages = (_all - services) - {"__pycache__"} @@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None: plugin_module, _, json_data = test_data for _ in range(repeat): - for json_sample in json_data: + for sample in json_data: + if sample.belongs_to(test_input_config.non_symmetrical_json): + continue + message: betterproto.Message = plugin_module.Test() - message.from_json(json_sample) + message.from_json(sample.json) message_json = message.to_json(0) - assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( - json.loads(json_sample) - ) + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( + json.loads(sample.json) + ) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None: def test_binary_compatibility(repeat, test_data: TestData) -> None: plugin_module, reference_module, json_data = test_data - for json_sample in json_data: - reference_instance = Parse(json_sample, reference_module().Test()) + for sample in json_data: + reference_instance = Parse(sample.json, reference_module().Test()) reference_binary_output = reference_instance.SerializeToString() for _ in range(repeat): plugin_instance_from_json: betterproto.Message = ( - plugin_module.Test().from_json(json_sample) + plugin_module.Test().from_json(sample.json) ) plugin_instance_from_binary = plugin_module.Test.FromString( reference_binary_output diff --git a/tests/util.py b/tests/util.py index 5dcf15552..950cf7af7 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,11 +1,11 @@ import asyncio +from dataclasses import dataclass import importlib import os -import pathlib -import sys from pathlib import Path +import sys from types import ModuleType -from typing import Callable, Generator, List, Optional, Union +from typing import Callable, Dict, Generator, List, Optional, Tuple, Union os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -47,11 +47,24 @@ async def protoc( return stdout, stderr, proc.returncode -def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]: +@dataclass +class TestCaseJsonFile: + json: str + test_name: str + file_name: str + + def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]): + return self.file_name in non_symmetrical_json.get(self.test_name, tuple()) + + +def get_test_case_json_data( + test_case_name: str, *json_file_names: str +) -> List[TestCaseJsonFile]: """ :return: - A list of all files found in "inputs_path/test_case_name" with names matching - f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names + A list of all files found in "{inputs_path}/test_case_name" with names matching + f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by + json_file_names """ test_case_dir = inputs_path.joinpath(test_case_name) possible_file_paths = [ @@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[ if not test_data_file_path.exists(): continue with test_data_file_path.open("r") as fh: - result.append(fh.read()) + result.append( + TestCaseJsonFile( + fh.read(), test_case_name, test_data_file_path.name.split(".")[0] + ) + ) return result @@ -86,7 +103,7 @@ def find_module( if predicate(module): return module - module_path = pathlib.Path(*module.__path__) + module_path = Path(*module.__path__) for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]: if sub == module_path: