Skip to content

Commit b627686

Browse files
fix(parsing): correctly handle nested discriminated unions
1 parent ef5bc36 commit b627686

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

src/llama_api_client/_models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import os
44
import inspect
5-
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast
5+
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
66
from datetime import date, datetime
77
from typing_extensions import (
8+
List,
89
Unpack,
910
Literal,
1011
ClassVar,
@@ -366,7 +367,7 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
366367
if type_ is None:
367368
raise RuntimeError(f"Unexpected field type is None for {key}")
368369

369-
return construct_type(value=value, type_=type_)
370+
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
370371

371372

372373
def is_basemodel(type_: type) -> bool:
@@ -420,7 +421,7 @@ def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
420421
return cast(_T, construct_type(value=value, type_=type_))
421422

422423

423-
def construct_type(*, value: object, type_: object) -> object:
424+
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
424425
"""Loose coercion to the expected type with construction of nested values.
425426
426427
If the given value does not match the expected type then it is returned as-is.
@@ -438,8 +439,10 @@ def construct_type(*, value: object, type_: object) -> object:
438439
type_ = type_.__value__ # type: ignore[unreachable]
439440

440441
# unwrap `Annotated[T, ...]` -> `T`
441-
if is_annotated_type(type_):
442-
meta: tuple[Any, ...] = get_args(type_)[1:]
442+
if metadata is not None:
443+
meta: tuple[Any, ...] = tuple(metadata)
444+
elif is_annotated_type(type_):
445+
meta = get_args(type_)[1:]
443446
type_ = extract_type_arg(type_, 0)
444447
else:
445448
meta = tuple()

tests/test_models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,48 @@ class ModelB(BaseModel):
889889
)
890890

891891
assert isinstance(m, ModelB)
892+
893+
894+
def test_nested_discriminated_union() -> None:
895+
class InnerType1(BaseModel):
896+
type: Literal["type_1"]
897+
898+
class InnerModel(BaseModel):
899+
inner_value: str
900+
901+
class InnerType2(BaseModel):
902+
type: Literal["type_2"]
903+
some_inner_model: InnerModel
904+
905+
class Type1(BaseModel):
906+
base_type: Literal["base_type_1"]
907+
value: Annotated[
908+
Union[
909+
InnerType1,
910+
InnerType2,
911+
],
912+
PropertyInfo(discriminator="type"),
913+
]
914+
915+
class Type2(BaseModel):
916+
base_type: Literal["base_type_2"]
917+
918+
T = Annotated[
919+
Union[
920+
Type1,
921+
Type2,
922+
],
923+
PropertyInfo(discriminator="base_type"),
924+
]
925+
926+
model = construct_type(
927+
type_=T,
928+
value={
929+
"base_type": "base_type_1",
930+
"value": {
931+
"type": "type_2",
932+
},
933+
},
934+
)
935+
assert isinstance(model, Type1)
936+
assert isinstance(model.value, InnerType2)

0 commit comments

Comments
 (0)