Closed
Description
The `tagged_union_schema` function in `pydantic_core.core_schema` is currently typed as:
def tagged_union_schema(
choices: Dict[Hashable, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable],
*,
custom_error_type: str | None = None,
custom_error_message: str | None = None,
custom_error_context: dict[str, int | str | float] | None = None,
strict: bool | None = None,
from_attributes: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> TaggedUnionSchema:
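However, `Dict` is invariant in its key type, so even the tagged-union example code makes pyright and mypy report errors as soon as you, e.g., save the dict literal to a variable. When the literal is passed inline, the type checker infers it against the expected parameter type. Once it is stored in a variable, though, its key type is inferred as `str`, and `dict[str, ...]` is not assignable to `Dict[Hashable, ...]`: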
from pydantic_core import SchemaValidator, core_schema
apple_schema = core_schema.typed_dict_schema(
{
'foo': core_schema.typed_dict_field(core_schema.str_schema()),
'bar': core_schema.typed_dict_field(core_schema.int_schema()),
}
)
banana_schema = core_schema.typed_dict_schema(
{
'foo': core_schema.typed_dict_field(core_schema.str_schema()),
'spam': core_schema.typed_dict_field(
core_schema.list_schema(items_schema=core_schema.int_schema())
),
}
)
choices = {
'apple': apple_schema,
'banana': banana_schema,
}
schema = core_schema.tagged_union_schema(
choices=choices,
discriminator='foo',
)
v = SchemaValidator(schema)
assert v.validate_python({'foo': 'apple', 'bar': '123'}) == {'foo': 'apple', 'bar': 123}
assert v.validate_python({'foo': 'banana', 'spam': [1, 2, 3]}) == {
'foo': 'banana',
'spam': [1, 2, 3],
}
mypy error:
error: Argument "choices" to "tagged_union_schema" has incompatible type "Dict[str, TypedDict({'type': Literal['typed-dict'], 'fields': Dict[str, TypedDictField], 'computed_fields'?: List[ComputedField], 'strict'?: bool, 'extra_validator'?: Mapping[str, Any], 'extra_behavior'?: Literal['allow', 'forbid', 'ignore'], 'total'?: bool, 'populate_by_name'?: bool, 'ref'?: str, 'metadata'?: Any, 'serialization'?: Union[SimpleSerSchema, PlainSerializerFunctionSerSchema, WrapSerializerFunctionSerSchema, FormatSerSchema, ToStringSerSchema, ModelSerSchema], 'config'?: CoreConfig})]"; expected "Dict[Hashable, Mapping[str, Any]]" [arg-type]
pyright error:
Argument of type "dict[str, CoreSchema]" cannot be assigned to parameter "choices" of type "Dict[Hashable, CoreSchema]" in function "tagged_union_schema"
"dict[str, CoreSchema]" is incompatible with "Dict[Hashable, CoreSchema]"
Type parameter "_KT@dict" is invariant, but "str" is not the same as "Hashable"
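An easy fix would be to make the key type a `TypeVar` bound to `Hashable`. The checker then infers the concrete key type (e.g. `str`) from the argument instead of requiring exactly `Hashable`. Note that `dict` is also invariant in its value type, which is why mypy additionally rejects the concrete `TypedDict` values. Typing `choices` as `Mapping[H, CoreSchema]`, which is covariant in its values, would address that as well. For example: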
H = TypeVar("H", bound=Hashable)
def tagged_union_schema(
choices: dict[H, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], H],
*,
custom_error_type: str | None = None,
custom_error_message: str | None = None,
custom_error_context: dict[str, int | str | float] | None = None,
strict: bool | None = None,
from_attributes: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> TaggedUnionSchema:
Selected Assignee: @lig