Skip to content

Commit

Permalink
fix: inspect types when resolving field marshallers for structured types
Browse files Browse the repository at this point in the history
Fixes an issue where two structured types within the same graph have conflicting attribute names, resulting in incorrect marshaller assignment for those fields.
  • Loading branch information
seandstewart committed Oct 16, 2024
1 parent 6cb9b75 commit 78d4896
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 26 deletions.
13 changes: 11 additions & 2 deletions src/typelib/marshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid

from typelib import graph, serdes
from typelib.py import compat, inspection
from typelib.py import compat, inspection, refs

T = tp.TypeVar("T")

Expand Down Expand Up @@ -452,7 +452,16 @@ def __init__(self, t: type[_ST], context: ContextT, *, var: str | None = None):
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.fields_by_var = {m.var: m for m in self.context.values() if m.var}
self.fields_by_var = self._fields_by_var()

def _fields_by_var(self):
fields_by_var = {}
tp_var_map = {(t.type, t.var): m for t, m in self.context.items()}
hints = inspection.cached_type_hints(self.t)
for name, hint in hints.items():
resolved = refs.evaluate(hint)
fields_by_var[name] = tp_var_map[(resolved, name)]
return fields_by_var

def __call__(self, val: _ST) -> MarshalledMappingT:
"""Marshal a structured type into a simple [`dict`][].
Expand Down
13 changes: 11 additions & 2 deletions src/typelib/unmarshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import uuid

from typelib import constants, graph, serdes
from typelib.py import compat, inspection
from typelib.py import compat, inspection, refs

T = tp.TypeVar("T")

Expand Down Expand Up @@ -967,7 +967,16 @@ def __init__(self, t: type[_ST], context: ContextT, *, var: str | None = None):
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.fields_by_var = {m.var: m for m in self.context.values() if m.var}
self.fields_by_var = self._fields_by_var()

def _fields_by_var(self):
fields_by_var = {}
tp_var_map = {(t.type, t.var): m for t, m in self.context.items()}
hints = inspection.cached_type_hints(self.t)
for name, hint in hints.items():
resolved = refs.evaluate(hint)
fields_by_var[name] = tp_var_map[(resolved, name)]
return fields_by_var

def __call__(self, val: tp.Any) -> _ST:
"""Unmarshal a value into the bound type.
Expand Down
21 changes: 21 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,24 @@ class UnionSTDLib:
timestamp: datetime.datetime | None = None
date_time: datetime.datetime | None = None
intstr: int | str = 0


@dataclasses.dataclass
class Parent:
intersection: ParentIntersect
child: Child


@dataclasses.dataclass
class Child:
intersection: ChildIntersect


@dataclasses.dataclass
class ParentIntersect:
a: int


@dataclasses.dataclass
class ChildIntersect:
b: int
8 changes: 8 additions & 0 deletions tests/unit/marshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@
given_input=models.GivenEnum.one,
expected_output=models.GivenEnum.one.value,
),
attrib_conflict=dict(
given_type=models.Parent,
given_input=models.Parent(
intersection=models.ParentIntersect(a=0),
child=models.Child(intersection=models.ChildIntersect(b=0)),
),
expected_output={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},
),
)
def test_marshal(given_type, given_input, expected_output):
# When
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/marshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest

from typelib import graph
from typelib.marshals import routines

from tests import models
Expand Down Expand Up @@ -386,8 +387,12 @@ def test_fixed_tuple_unmarshaller(
@pytest.mark.suite(
context=dict(
given_context={
int: routines.IntegerMarshaller(int, {}, var="value"),
str: routines.StringMarshaller(str, {}, var="field"),
graph.TypeNode(int, var="value"): routines.IntegerMarshaller(
int, {}, var="value"
),
graph.TypeNode(str, var="field"): routines.StringMarshaller(
str, {}, var="field"
),
},
expected_output=dict(field="data", value=1),
),
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/unmarshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@
timestamp=datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
),
),
attrib_conflict=dict(
given_type=models.Parent,
given_input={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},
expected_output=models.Parent(
intersection=models.ParentIntersect(a=0),
child=models.Child(intersection=models.ChildIntersect(b=0)),
),
),
)
def test_unmarshal(given_type, given_input, expected_output):
# When
Expand Down
33 changes: 13 additions & 20 deletions tests/unit/unmarshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest

from typelib import graph
from typelib.unmarshals import routines

from tests import models
Expand Down Expand Up @@ -718,44 +719,36 @@ def test_fixed_tuple_unmarshaller(


@pytest.mark.suite(
dataclass=dict(
given_cls=models.Data,
context=dict(
given_context={
int: routines.NumberUnmarshaller(int, {}, var="value"),
str: routines.StringUnmarshaller(str, {}, var="field"),
graph.TypeNode(int, var="value"): routines.NumberUnmarshaller(
int, {}, var="value"
),
graph.TypeNode(str, var="field"): routines.StringUnmarshaller(
str, {}, var="field"
),
},
),
)
@pytest.mark.suite(
dataclass=dict(
given_cls=models.Data,
expected_output=models.Data(field="data", value=1),
),
vanilla=dict(
given_cls=models.Vanilla,
given_context={
int: routines.NumberUnmarshaller(int, {}, var="value"),
str: routines.StringUnmarshaller(str, {}, var="field"),
},
expected_output=models.Vanilla(field="data", value=1),
),
vanilla_with_hints=dict(
given_cls=models.VanillaWithHints,
given_context={
int: routines.NumberUnmarshaller(int, {}, var="value"),
str: routines.StringUnmarshaller(str, {}, var="field"),
},
expected_output=models.VanillaWithHints(field="data", value=1),
),
named_tuple=dict(
given_cls=models.NTuple,
given_context={
int: routines.NumberUnmarshaller(int, {}, var="value"),
str: routines.StringUnmarshaller(str, {}, var="field"),
},
expected_output=models.NTuple(field="data", value=1),
),
typed_dict=dict(
given_cls=models.TDict,
given_context={
int: routines.NumberUnmarshaller(int, {}, var="value"),
str: routines.StringUnmarshaller(str, {}, var="field"),
},
expected_output=models.TDict(field="data", value=1),
),
)
Expand Down

0 comments on commit 78d4896

Please sign in to comment.