tagged_union_schema type is incorrect #925

Closed
@ImogenBits

Description

The tagged_union_schema function in pydantic_core.core_schema is currently typed as:

def tagged_union_schema(
    choices: Dict[Hashable, CoreSchema],
    discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable],
    *,
    custom_error_type: str | None = None,
    custom_error_message: str | None = None,
    custom_error_context: dict[str, int | str | float] | None = None,
    strict: bool | None = None,
    from_attributes: bool | None = None,
    ref: str | None = None,
    metadata: Any = None,
    serialization: SerSchema | None = None,
) -> TaggedUnionSchema: ...

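But Dict is invariant in its key type, so even the tagged union example code makes pyright and mypy report errors as soon as you, e.g., save the dict literal to a variable first. Passed inline, the literal is inferred against the type of the choices parameter and accepted; stored in a variable, its keys are inferred as str, and dict[str, ...] is not a Dict[Hashable, ...]. For example: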

from pydantic_core import SchemaValidator, core_schema

apple_schema = core_schema.typed_dict_schema(
    {
        'foo': core_schema.typed_dict_field(core_schema.str_schema()),
        'bar': core_schema.typed_dict_field(core_schema.int_schema()),
    }
)
banana_schema = core_schema.typed_dict_schema(
    {
        'foo': core_schema.typed_dict_field(core_schema.str_schema()),
        'spam': core_schema.typed_dict_field(
            core_schema.list_schema(items_schema=core_schema.int_schema())
        ),
    }
)
choices = {
    'apple': apple_schema,
    'banana': banana_schema,
}
schema = core_schema.tagged_union_schema(
    choices=choices,
    discriminator='foo',
)
v = SchemaValidator(schema)
assert v.validate_python({'foo': 'apple', 'bar': '123'}) == {'foo': 'apple', 'bar': 123}
assert v.validate_python({'foo': 'banana', 'spam': [1, 2, 3]}) == {
    'foo': 'banana',
    'spam': [1, 2, 3],
}

mypy error:

error: Argument "choices" to "tagged_union_schema" has incompatible type "Dict[str, TypedDict({'type': Literal['typed-dict'], 'fields': Dict[str, TypedDictField], 'computed_fields'?: List[ComputedField], 'strict'?: bool, 'extra_validator'?: Mapping[str, Any], 'extra_behavior'?: Literal['allow', 'forbid', 'ignore'], 'total'?: bool, 'populate_by_name'?: bool, 'ref'?: str, 'metadata'?: Any, 'serialization'?: Union[SimpleSerSchema, PlainSerializerFunctionSerSchema, WrapSerializerFunctionSerSchema, FormatSerSchema, ToStringSerSchema, ModelSerSchema], 'config'?: CoreConfig})]"; expected "Dict[Hashable, Mapping[str, Any]]"  [arg-type] 

pyright error:

Argument of type "dict[str, CoreSchema]" cannot be assigned to parameter "choices" of type "Dict[Hashable, CoreSchema]" in function "tagged_union_schema"
  "dict[str, CoreSchema]" is incompatible with "Dict[Hashable, CoreSchema]"
    Type parameter "_KT@dict" is invariant, but "str" is not the same as "Hashable"
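The pyright message names the underlying rule. Below is a stdlib-only sketch of why the checkers insist on invariance here, and of the two call shapes they do accept. The register function is hypothetical and not part of pydantic_core: because its parameter is Dict[Hashable, str], it may legally add a non-str key, which would corrupt a caller's Dict[str, str] if that were accepted. Annotating the variable with Hashable keys (as in the last lines) should also work as a caller-side workaround until the signature changes.

```python
from typing import Dict, Hashable


def register(choices: Dict[Hashable, str]) -> None:
    # With a Dict[Hashable, ...] parameter the callee may legally add any hashable key.
    choices[0] = "zero"


tags = {"apple": "a", "banana": "b"}  # inferred as Dict[str, str]
# register(tags)  # rejected by mypy and pyright: Dict is invariant in its key type.
# Were it accepted, `tags` would now hold an int key despite its Dict[str, str] type.

register({"apple": "a"})  # accepted: the literal is inferred against the parameter type

annotated: Dict[Hashable, str] = {"apple": "a", "banana": "b"}
register(annotated)  # accepted: the declared key type matches exactly
```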

An easy way to solve this would be to make the key type a TypeVar bound to Hashable instead, like this:

from __future__ import annotations

from typing import Any, Callable, Hashable, TypeVar

# CoreSchema, SerSchema and TaggedUnionSchema are the existing types in pydantic_core.core_schema.
H = TypeVar("H", bound=Hashable)


def tagged_union_schema(
    choices: dict[H, CoreSchema],
    discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], H],
    *,
    custom_error_type: str | None = None,
    custom_error_message: str | None = None,
    custom_error_context: dict[str, int | str | float] | None = None,
    strict: bool | None = None,
    from_attributes: bool | None = None,
    ref: str | None = None,
    metadata: Any = None,
    serialization: SerSchema | None = None,
) -> TaggedUnionSchema:
    ...  # existing implementation unchanged
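To see the bound TypeVar let a dict[str, ...] variable through, here is a stdlib-only toy with the same shape. Everything in it (ToySchema, toy_tagged_union) is hypothetical and far simpler than pydantic_core: a "schema" is just one coercion function per field, and the discriminator is either a field name or a callable. With choices inferred as Dict[str, ToySchema], a checker can solve H as str instead of demanding exactly Hashable. Dict stays invariant in its value type, which is why every value here is a ToySchema.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, TypeVar, Union

H = TypeVar("H", bound=Hashable)


@dataclass
class ToySchema:
    """Hypothetical stand-in for a CoreSchema: one coercion function per field."""

    fields: Dict[str, Callable[[Any], Any]]

    def validate(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Keep only the declared fields and coerce each one.
        return {name: coerce(data[name]) for name, coerce in self.fields.items()}


def toy_tagged_union(
    choices: Dict[H, ToySchema],
    discriminator: Union[str, Callable[[Any], H]],
) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Return a validator that picks a choice by its tag (illustration only)."""

    def validate(data: Dict[str, Any]) -> Dict[str, Any]:
        # A str discriminator names the field holding the tag; a callable computes it.
        tag = data[discriminator] if isinstance(discriminator, str) else discriminator(data)
        if tag not in choices:
            raise ValueError(f"unknown tag {tag!r}")
        return choices[tag].validate(data)

    return validate


# Inferred as Dict[str, ToySchema]; H is solved as str, so there is no invariance error.
choices = {
    "apple": ToySchema({"foo": str, "bar": int}),
    "banana": ToySchema({"foo": str, "spam": list}),
}
validate = toy_tagged_union(choices, "foo")
```

Calling validate({"foo": "apple", "bar": "123"}) returns {"foo": "apple", "bar": 123}, mirroring the pydantic_core example above.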

Assignee: @lig