Skip to content

Commit d1a067c

Browse files
fix(types): handle more discriminated union shapes (#213)
1 parent 7efe8a4 commit d1a067c

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/writerai/_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from ._constants import RAW_RESPONSE_HEADER
6767

6868
if TYPE_CHECKING:
69-
from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema
69+
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
7070

7171
__all__ = ["BaseModel", "GenericModel"]
7272

@@ -647,15 +647,18 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
647647

648648
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
649649
schema = model.__pydantic_core_schema__
650+
if schema["type"] == "definitions":
651+
schema = schema["schema"]
652+
650653
if schema["type"] != "model":
651654
return None
652655

656+
schema = cast("ModelSchema", schema)
653657
fields_schema = schema["schema"]
654658
if fields_schema["type"] != "model-fields":
655659
return None
656660

657661
fields_schema = cast("ModelFieldsSchema", fields_schema)
658-
659662
field = fields_schema["fields"].get(field_name)
660663
if not field:
661664
return None

tests/test_models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,3 +854,35 @@ class Model(BaseModel):
854854
m = construct_type(value={"cls": "foo"}, type_=Model)
855855
assert isinstance(m, Model)
856856
assert isinstance(m.cls, str)
857+
858+
859+
def test_discriminated_union_case() -> None:
860+
class A(BaseModel):
861+
type: Literal["a"]
862+
863+
data: bool
864+
865+
class B(BaseModel):
866+
type: Literal["b"]
867+
868+
data: List[Union[A, object]]
869+
870+
class ModelA(BaseModel):
871+
type: Literal["modelA"]
872+
873+
data: int
874+
875+
class ModelB(BaseModel):
876+
type: Literal["modelB"]
877+
878+
required: str
879+
880+
data: Union[A, B]
881+
882+
# when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required`
883+
m = construct_type(
884+
value={"type": "modelB", "data": {"type": "a", "data": True}},
885+
type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]),
886+
)
887+
888+
assert isinstance(m, ModelB)

0 commit comments

Comments
 (0)