From ec3079ea4f46c85c474efe82da3f459aa2c3ccfa Mon Sep 17 00:00:00 2001 From: Federico Busetti <729029+febus982@users.noreply.github.com> Date: Sat, 28 Sep 2024 14:50:14 +0100 Subject: [PATCH] Nested Binary fields improvements (#4) * Make benchmarks go through the base64 validator/serializer and update benchmark values * Improve nested Binary support --------- Signed-off-by: Federico Busetti <729029+febus982@users.noreply.github.com> --- README.md | 15 +++--- benchmark.py | 26 ++++++---- cloudevents_pydantic/events/_event.py | 41 +++++++-------- .../events/field_types/_canonic_types.py | 22 ++++++-- cloudevents_pydantic/formats/json.py | 4 +- docs/event_class.md | 2 +- .../events/test_field_types_serialization.py | 41 +++++++++++++-- tests/events/test_field_types_validation.py | 42 ++++++++++++++- tests/formats/test_json.py | 51 ++++++++++++++++--- 9 files changed, 188 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 88f97cc..f9e57e9 100644 --- a/README.md +++ b/README.md @@ -69,17 +69,16 @@ some performance issue in the official serialization using pydantic) These results come from a Macbook Pro M3 Max on python 3.12. Feel free to run the `benchmark.py` script yourself.
-```shell +``` Timings for HTTP JSON deserialization: -This package: 2.5353065830422565 -Official SDK with pydantic model: 12.80780174996471 -Official SDK with http model: 11.474249749968294 +This package: 3.0855846670019673 +Official SDK with pydantic model: 15.35431600001175 +Official SDK with http model: 13.728038166998886 Timings for HTTP JSON serialization: -This package: 3.4850796660175547 -Official SDK with pydantic model: 39.037468083028216 -Official SDK with http model: 7.681282749981619 - +This package: 4.292417042001034 +Official SDK with pydantic model: 44.50933354199515 +Official SDK with http model: 8.929204874992138 ``` diff --git a/benchmark.py b/benchmark.py index 8f6ac1b..d6a86f8 100755 --- a/benchmark.py +++ b/benchmark.py @@ -20,6 +20,7 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER = # DEALINGS IN THE SOFTWARE. = # ============================================================================== +import base64 import json from timeit import timeit @@ -36,11 +37,18 @@ from cloudevents.pydantic import ( from_json as from_json_pydantic, ) +from pydantic import Field from cloudevents_pydantic.bindings.http import HTTPHandler from cloudevents_pydantic.events import CloudEvent +from cloudevents_pydantic.events.field_types import Binary -valid_json = '{"data":null,"source":"https://example.com/event-producer","id":"b96267e2-87be-4f7a-b87c-82f64360d954","type":"com.example.string","specversion":"1.0","time":"2022-07-16T12:03:20.519216+04:00","subject":null,"datacontenttype":null,"dataschema":null}' +valid_json = '{"data_base64":"dGVzdA==","source":"https://example.com/event-producer","id":"b96267e2-87be-4f7a-b87c-82f64360d954","type":"com.example.string","specversion":"1.0","time":"2022-07-16T12:03:20.519216+04:00","subject":null,"datacontenttype":null,"dataschema":null}' +test_iterations = 1000000 + + +class BinaryEvent(CloudEvent): + data: Binary = Field(Binary, alias="data_base64") def json_deserialization(): @@ -56,19 +64,19 
@@ def json_deserialization_official_sdk_cloudevent(): print("Timings for HTTP JSON deserialization:") -print("This package: " + str(timeit(json_deserialization))) +print("This package: " + str(timeit(json_deserialization, number=test_iterations))) print( "Official SDK with pydantic model: " - + str(timeit(json_deserialization_official_sdk_pydantic)) + + str(timeit(json_deserialization_official_sdk_pydantic, number=test_iterations)) ) print( "Official SDK with http model: " - + str(timeit(json_deserialization_official_sdk_cloudevent)) + + str(timeit(json_deserialization_official_sdk_cloudevent, number=test_iterations)) ) attributes = json.loads(valid_json) -data = attributes["data"] -del attributes["data"] +data = base64.b64decode(attributes["data_base64"]) +del attributes["data_base64"] event = CloudEvent(**attributes, data=data) http_handler = HTTPHandler() official_pydantic_event = PydanticOfficialCloudEvent.create( @@ -91,12 +99,12 @@ def json_serialization_official_sdk_cloudevent(): print("") print("Timings for HTTP JSON serialization:") -print("This package: " + str(timeit(json_serialization))) +print("This package: " + str(timeit(json_serialization, number=test_iterations))) print( "Official SDK with pydantic model: " - + str(timeit(json_serialization_official_sdk_pydantic)) + + str(timeit(json_serialization_official_sdk_pydantic, number=test_iterations)) ) print( "Official SDK with http model: " - + str(timeit(json_serialization_official_sdk_cloudevent)) + + str(timeit(json_serialization_official_sdk_cloudevent, number=test_iterations)) ) diff --git a/cloudevents_pydantic/events/_event.py b/cloudevents_pydantic/events/_event.py index 48832eb..4c5572d 100644 --- a/cloudevents_pydantic/events/_event.py +++ b/cloudevents_pydantic/events/_event.py @@ -22,8 +22,7 @@ # ============================================================================== import base64 import datetime -import typing -from typing import Union +from typing import Any, Dict, Optional, Union 
from cloudevents.pydantic.fields_docs import FIELD_DESCRIPTIONS from pydantic import ( @@ -33,14 +32,18 @@ model_serializer, model_validator, ) +from pydantic.fields import FieldInfo from pydantic_core.core_schema import ValidationInfo from ulid import ULID -from .field_types import URI, DateTime, SpecVersion, String, URIReference +from .field_types import URI, Binary, DateTime, SpecVersion, String, URIReference DEFAULT_SPECVERSION = SpecVersion.v1_0 +_binary_field_metadata = FieldInfo.from_annotation(Binary).metadata + + class CloudEvent(BaseModel): # type: ignore """ A Python-friendly CloudEvent representation backed by Pydantic-modeled fields. @@ -49,9 +52,9 @@ class CloudEvent(BaseModel): # type: ignore @classmethod def event_factory( cls, - id: typing.Optional[str] = None, - specversion: typing.Optional[SpecVersion] = None, - time: typing.Optional[Union[datetime.datetime, str]] = None, + id: Optional[str] = None, + specversion: Optional[SpecVersion] = None, + time: Optional[Union[datetime.datetime, str]] = None, **kwargs, ) -> "CloudEvent": """ @@ -74,7 +77,7 @@ def event_factory( **kwargs, ) - data: typing.Optional[typing.Any] = Field( + data: Any = Field( title=FIELD_DESCRIPTIONS["data"].get("title"), description=FIELD_DESCRIPTIONS["data"].get("description"), examples=[FIELD_DESCRIPTIONS["data"].get("example")], @@ -104,25 +107,25 @@ def event_factory( ) # Optional fields - time: typing.Optional[DateTime] = Field( + time: Optional[DateTime] = Field( title=FIELD_DESCRIPTIONS["time"].get("title"), description=FIELD_DESCRIPTIONS["time"].get("description"), examples=[FIELD_DESCRIPTIONS["time"].get("example")], default=None, ) - subject: typing.Optional[String] = Field( + subject: Optional[String] = Field( title=FIELD_DESCRIPTIONS["subject"].get("title"), description=FIELD_DESCRIPTIONS["subject"].get("description"), examples=[FIELD_DESCRIPTIONS["subject"].get("example")], default=None, ) - datacontenttype: typing.Optional[String] = Field( + datacontenttype: 
Optional[String] = Field( title=FIELD_DESCRIPTIONS["datacontenttype"].get("title"), description=FIELD_DESCRIPTIONS["datacontenttype"].get("description"), examples=[FIELD_DESCRIPTIONS["datacontenttype"].get("example")], default=None, ) - dataschema: typing.Optional[URI] = Field( + dataschema: Optional[URI] = Field( title=FIELD_DESCRIPTIONS["dataschema"].get("title"), description=FIELD_DESCRIPTIONS["dataschema"].get("description"), examples=[FIELD_DESCRIPTIONS["dataschema"].get("example")], @@ -154,7 +157,7 @@ def event_factory( """ @model_serializer(when_used="json") - def base64_json_serializer(self) -> typing.Dict[str, typing.Any]: + def base64_json_serializer(self) -> Dict[str, Any]: """Takes care of handling binary data serialization into `data_base64` attribute. @@ -164,20 +167,18 @@ def base64_json_serializer(self) -> typing.Dict[str, typing.Any]: data handled. """ model_dict = self.model_dump() # type: ignore - - if isinstance(self.data, (bytes, bytearray, memoryview)): - model_dict["data_base64"] = ( - base64.b64encode(self.data) - if isinstance(self.data, (bytes, bytearray, memoryview)) - else self.data - ) + if _binary_field_metadata == self.model_fields["data"].metadata: + model_dict["data_base64"] = model_dict["data"] + del model_dict["data"] + elif isinstance(model_dict["data"], (bytes, bytearray, memoryview)): + model_dict["data_base64"] = base64.b64encode(model_dict["data"]) del model_dict["data"] return model_dict @model_validator(mode="before") @classmethod - def base64_data_parser(cls, data: typing.Any, info: ValidationInfo) -> typing.Any: + def base64_json_validator(cls, data: dict, info: ValidationInfo) -> Any: """Takes care of handling binary data deserialization from `data_base64` attribute. 
diff --git a/cloudevents_pydantic/events/field_types/_canonic_types.py b/cloudevents_pydantic/events/field_types/_canonic_types.py index 1300129..3391770 100644 --- a/cloudevents_pydantic/events/field_types/_canonic_types.py +++ b/cloudevents_pydantic/events/field_types/_canonic_types.py @@ -23,6 +23,7 @@ import base64 from datetime import datetime from enum import Enum +from typing import Annotated, Union from urllib.parse import ParseResult, urlparse, urlunparse from annotated_types import Ge, Le @@ -31,7 +32,6 @@ PlainValidator, StringConstraints, ) -from typing_extensions import Annotated def bool_serializer(value: bool) -> str: @@ -42,6 +42,16 @@ def binary_serializer(value: bytes) -> str: return base64.b64encode(value).decode() +def binary_validator(value: Union[str, bytes, bytearray, memoryview]) -> bytes: + if isinstance(value, (bytes, bytearray, memoryview)): + return value + + if isinstance(value, str): + return base64.b64decode(value, validate=True) + + raise ValueError(f"Unsupported value type: {type(value)} - {value}") + + def url_serializer(value: ParseResult) -> str: return urlunparse(value) @@ -98,10 +108,14 @@ def generic_uri_validator(value: str) -> ParseResult: Sequence of allowable Unicode characters """ -# bytearray is coerced to bytes, memoryview is not supported -Binary = Annotated[bytes, PlainSerializer(binary_serializer)] +Binary = Annotated[ + bytes, + PlainValidator(binary_validator), + PlainSerializer(binary_serializer), +] """ -Sequence of bytes supporting base64 serialization/deserialization +Sequence of bytes that accepts both bytes and base64 encoded strings as input +and is serialized to a base64 encoded string. 
""" URI = Annotated[ diff --git a/cloudevents_pydantic/formats/json.py b/cloudevents_pydantic/formats/json.py index 8717657..6823d0e 100644 --- a/cloudevents_pydantic/formats/json.py +++ b/cloudevents_pydantic/formats/json.py @@ -20,10 +20,10 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER = # DEALINGS IN THE SOFTWARE. = # ============================================================================== -from typing import List, Type +from typing import List, Type, overload from pydantic import TypeAdapter -from typing_extensions import TypeVar, overload +from typing_extensions import TypeVar from ..events import CloudEvent diff --git a/docs/event_class.md b/docs/event_class.md index ecfe940..501cac9 100644 --- a/docs/event_class.md +++ b/docs/event_class.md @@ -105,7 +105,7 @@ When you create event types in your app you will want to make sure to follow the will be compliant with the [CloudEvents spec](https://github.com/cloudevents/spec/tree/main). ```python -from typing_extensions import TypedDict, Literal +from typing import TypedDict, Literal from cloudevents_pydantic.events import CloudEvent, field_types class OrderCreatedData(TypedDict): diff --git a/tests/events/test_field_types_serialization.py b/tests/events/test_field_types_serialization.py index 27587ae..4f0138c 100644 --- a/tests/events/test_field_types_serialization.py +++ b/tests/events/test_field_types_serialization.py @@ -20,6 +20,8 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER = # DEALINGS IN THE SOFTWARE. 
= # ============================================================================== +from typing import Union + import pytest from pydantic import BaseModel @@ -45,18 +47,47 @@ class BoolModel(BaseModel): @pytest.mark.parametrize( - ["data", "expected_value"], + ["data", "serialized_output"], [ - pytest.param("test", "dGVzdA==", id="string"), pytest.param(b"test", "dGVzdA==", id="bytes"), - pytest.param(bytearray([2, 3, 5, 7]), "AgMFBw==", id="bytearray"), + pytest.param(b"\x02\x03\x05\x07", "AgMFBw==", id="bytearray"), ], ) -def test_binary_data_is_b64encoded(data, expected_value): +def test_binary_serialization( + data: Union[bytes, str], + serialized_output: str, +): class BinaryModel(BaseModel): value: Binary - assert BinaryModel(value=data).model_dump()["value"] == expected_value + model = BinaryModel(value=data) + + serialized_value = model.model_dump()["value"] + + assert serialized_value == serialized_output + assert isinstance(serialized_value, str) + + +@pytest.mark.parametrize( + ["data", "serialized_output"], + [ + pytest.param(b"test", "dGVzdA==", id="bytes"), + pytest.param(b"\x02\x03\x05\x07", "AgMFBw==", id="bytearray"), + ], +) +def test_nested_binary_serialization( + data: Union[bytes, str], + serialized_output: str, +): + class BinaryModel(BaseModel): + value: Binary + + model = BinaryModel(value=data) + + serialized_value = model.model_dump()["value"] + + assert serialized_value == serialized_output + assert isinstance(serialized_value, str) @pytest.mark.parametrize( diff --git a/tests/events/test_field_types_validation.py b/tests/events/test_field_types_validation.py index 8830fc2..bd580cf 100644 --- a/tests/events/test_field_types_validation.py +++ b/tests/events/test_field_types_validation.py @@ -20,13 +20,16 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER = # DEALINGS IN THE SOFTWARE. 
= # ============================================================================== +from typing import Literal, Union + import pytest from pydantic import BaseModel, ValidationError -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from cloudevents_pydantic.events import CloudEvent from cloudevents_pydantic.events.field_types import ( URI, + Binary, Integer, String, URIReference, @@ -45,6 +48,43 @@ class IntModel(BaseModel): assert IntModel(value=2312534).value == 2312534 +@pytest.mark.parametrize( + ["data", "validated_data"], + [ + pytest.param("dGVzdA==", b"test", id="bytes_base64"), + pytest.param(b"test", b"test", id="bytes"), + pytest.param("AgMFBw==", b"\x02\x03\x05\x07", id="bytearray_base64"), + pytest.param(b"\x02\x03\x05\x07", b"\x02\x03\x05\x07", id="bytearray"), + ], +) +def test_binary_validation_accepts_strings_and_bytes( + data: Union[bytes, str], + validated_data: bytes, +): + class BinaryModel(BaseModel): + value: Binary + + model = BinaryModel(value=data) + + assert model.value == validated_data + assert isinstance(model.value, bytes) + + +@pytest.mark.parametrize( + ["data"], + [ + pytest.param(99.99, id="float"), + pytest.param("non?-/*base64-string", id="non_base_64_string"), + ], +) +def test_binary_validation_fails_on_non_strings_and_bytes(data): + class BinaryModel(BaseModel): + value: Binary + + with pytest.raises(ValidationError): + BinaryModel(value=data) + + def test_strings_allows_valid_unicode_chars(): class StrModel(BaseModel): value: String diff --git a/tests/formats/test_json.py b/tests/formats/test_json.py index 2d421e9..2aa2fed 100644 --- a/tests/formats/test_json.py +++ b/tests/formats/test_json.py @@ -23,12 +23,13 @@ import datetime import json from pathlib import Path -from typing import Any, Sequence +from typing import Any, Dict, Sequence from urllib.parse import ParseResult import pytest from jsonschema import validate from pydantic import TypeAdapter, ValidationError +from 
typing_extensions import TypedDict from cloudevents_pydantic.events import CloudEvent from cloudevents_pydantic.events.field_types import Binary, SpecVersion @@ -45,7 +46,7 @@ "id": "b96267e2-87be-4f7a-b87c-82f64360d954", "specversion": "1.0", } -test_attributes = { +test_attributes: Dict[str, Any] = { "type": "com.example.string", "source": "https://example.com/event-producer", "id": "b96267e2-87be-4f7a-b87c-82f64360d954", @@ -56,10 +57,6 @@ valid_json_batch = f"[{valid_json}]" -class BinaryDataEvent(CloudEvent): - data: Binary - - with open( Path(__file__).parent.joinpath("cloudevents_jsonschema_1.0.2.json"), "r", @@ -327,6 +324,9 @@ def test_to_json_base64_with_binary_type( b64_expected: bool, batch: bool, ): + class BinaryDataEvent(CloudEvent): + data: Binary + input_attrs = test_attributes.copy() input_attrs["data"] = data event = BinaryDataEvent.event_factory(**input_attrs) @@ -360,6 +360,9 @@ def test_to_json_base64_with_binary_type( def test_from_json_base64_with_binary_type( b64_data: str, expected_value: type, batch: bool ): + class BinaryDataEvent(CloudEvent): + data: Binary + json_string = ( '{"data_base64":"' + b64_data @@ -375,3 +378,39 @@ def test_from_json_base64_with_binary_type( event = from_json(json_string, BinaryDataEvent) assert event.data == expected_value assert isinstance(event, BinaryDataEvent) + + +def test_nested_binary_fields_correctly_serialized(): + class SomeData(TypedDict): + # Using `data` to verify it doesn't become `data_base64` when nested + data: Binary + test: str + + class BinaryNestedEvent(CloudEvent): + data: SomeData + + input_attrs = test_attributes.copy() + input_attrs["data"] = { + "data": b"\x02\x03\x05\x07", + "test": "test", + } + event = BinaryNestedEvent.event_factory(**input_attrs) + assert ( + to_json(event) + == 
'{"data":{"data":"AgMFBw==","test":"test"},"source":"https://example.com/event-producer","id":"b96267e2-87be-4f7a-b87c-82f64360d954","type":"com.example.string","specversion":"1.0","time":"2022-07-16T12:03:20.519216+04:00","subject":null,"datacontenttype":null,"dataschema":null}' + ) + + +def test_nested_binary_fields_correctly_deserialized(): + json_input = '{"data":{"data":"AgMFBw==","test":"test"},"source":"https://example.com/event-producer","id":"b96267e2-87be-4f7a-b87c-82f64360d954","type":"com.example.string","specversion":"1.0","time":"2022-07-16T12:03:20.519216+04:00","subject":null,"datacontenttype":null,"dataschema":null}' + + class SomeData(TypedDict): + # Using `data` to verify it doesn't become `data_base64` when nested + data: Binary + test: str + + class BinaryNestedEvent(CloudEvent): + data: SomeData + + event: BinaryNestedEvent = from_json(json_input, BinaryNestedEvent) + assert event.data["data"] == b"\x02\x03\x05\x07"