Skip to content

Commit

Permalink
Flatten CoreSchema types to get a single discriminant key (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Mar 17, 2023
1 parent 28ff34e commit 5efeaf9
Show file tree
Hide file tree
Showing 34 changed files with 348 additions and 499 deletions.
19 changes: 1 addition & 18 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,9 @@ def main() -> None:
assert m, f'Unknown schema type: {type_}'
key = m.group(1)
value = get_schema(s)
if key == 'function':
mode = value['fields']['mode']['schema']['expected']
if mode == ['plain']:
key = 'function-plain'
elif mode == ['wrap']:
key = 'function-wrap'
elif key == 'tuple':
if value['fields']['mode']['schema']['expected'] == ['positional']:
key = 'tuple-positional'
else:
key = 'tuple-variable'

choices[key] = value

schema = {
'type': 'tagged-union',
'ref': 'root-schema',
'discriminator': 'self-schema-discriminator',
'choices': choices,
}
schema = {'type': 'tagged-union', 'ref': 'root-schema', 'discriminator': 'type', 'choices': choices}
python_code = (
f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n'
)
Expand Down
73 changes: 35 additions & 38 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,8 +1139,7 @@ def list_schema(


class TuplePositionalSchema(TypedDict, total=False):
type: Required[Literal['tuple']]
mode: Required[Literal['positional']]
type: Required[Literal['tuple-positional']]
items_schema: Required[List[CoreSchema]]
extra_schema: CoreSchema
strict: bool
Expand Down Expand Up @@ -1179,8 +1178,7 @@ def tuple_positional_schema(
serialization: Custom serialization schema
"""
return dict_not_none(
type='tuple',
mode='positional',
type='tuple-positional',
items_schema=list(items_schema),
extra_schema=extra_schema,
strict=strict,
Expand All @@ -1191,8 +1189,7 @@ def tuple_positional_schema(


class TupleVariableSchema(TypedDict, total=False):
type: Required[Literal['tuple']]
mode: Literal['variable']
type: Required[Literal['tuple-variable']]
items_schema: CoreSchema
min_length: int
max_length: int
Expand Down Expand Up @@ -1232,8 +1229,7 @@ def tuple_variable_schema(
serialization: Custom serialization schema
"""
return dict_not_none(
type='tuple',
mode='variable',
type='tuple-variable',
items_schema=items_schema,
min_length=min_length,
max_length=max_length,
Expand Down Expand Up @@ -1509,24 +1505,26 @@ class GeneralValidatorFunctionSchema(TypedDict):
function: GeneralValidatorFunction


class FunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
class _FunctionSchema(TypedDict, total=False):
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
mode: Required[Literal['before', 'after']]
schema: Required[CoreSchema]
ref: str
metadata: Any
serialization: SerSchema


class FunctionBeforeSchema(_FunctionSchema, total=False):
type: Required[Literal['function-before']]


def field_before_validation_function(
function: FieldValidatorFunction,
schema: CoreSchema,
*,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionBeforeSchema:
"""
Returns a schema that calls a validator function before validating
the provided **model field** schema, e.g.:
Expand Down Expand Up @@ -1556,8 +1554,7 @@ def fn(v: bytes, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='before',
type='function-before',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1573,7 +1570,7 @@ def general_before_validation_function(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionBeforeSchema:
"""
Returns a schema that calls a validator function before validating the provided schema, e.g.:
Expand All @@ -1599,8 +1596,7 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='before',
type='function-before',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1609,14 +1605,18 @@ def fn(v: Any, info: core_schema.ValidationInfo) -> str:
)


class FunctionAfterSchema(_FunctionSchema, total=False):
type: Required[Literal['function-after']]


def field_after_validation_function(
function: FieldValidatorFunction,
schema: CoreSchema,
*,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionAfterSchema:
"""
Returns a schema that calls a validator function after validating
the provided **model field** schema, e.g.:
Expand Down Expand Up @@ -1646,8 +1646,7 @@ def fn(v: str, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='after',
type='function-after',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1663,7 +1662,7 @@ def general_after_validation_function(
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> FunctionSchema:
) -> FunctionAfterSchema:
"""
Returns a schema that calls a validator function after validating the provided schema, e.g.:
Expand All @@ -1687,8 +1686,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='after',
type='function-after',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand Down Expand Up @@ -1727,9 +1725,8 @@ class GeneralWrapValidatorFunctionSchema(TypedDict):


class WrapFunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
type: Required[Literal['function-wrap']]
function: Required[Union[GeneralWrapValidatorFunctionSchema, FieldWrapValidatorFunctionSchema]]
mode: Required[Literal['wrap']]
schema: Required[CoreSchema]
ref: str
metadata: Any
Expand Down Expand Up @@ -1768,8 +1765,7 @@ def fn(v: str, validator: core_schema.CallableValidator, info: core_schema.Valid
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='wrap',
type='function-wrap',
function={'type': 'general', 'function': function},
schema=schema,
ref=ref,
Expand Down Expand Up @@ -1817,8 +1813,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='wrap',
type='function-wrap',
function={'type': 'field', 'function': function},
schema=schema,
ref=ref,
Expand All @@ -1828,8 +1823,7 @@ def fn(v: bytes, validator: core_schema.CallableValidator, info: core_schema.Mod


class PlainFunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
mode: Required[Literal['plain']]
type: Required[Literal['function-plain']]
function: Required[Union[FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema]]
ref: str
metadata: Any
Expand Down Expand Up @@ -1865,8 +1859,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='plain',
type='function-plain',
function={'type': 'general', 'function': function},
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -1909,8 +1902,7 @@ def fn(v: Any, info: core_schema.ModelFieldValidationInfo) -> str:
serialization: Custom serialization schema
"""
return dict_not_none(
type='function',
mode='plain',
type='function-plain',
function={'type': 'field', 'function': function},
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -3068,7 +3060,8 @@ def definition_reference_schema(
FrozenSetSchema,
GeneratorSchema,
DictSchema,
FunctionSchema,
FunctionAfterSchema,
FunctionBeforeSchema,
WrapFunctionSchema,
PlainFunctionSchema,
WithDefaultSchema,
Expand Down Expand Up @@ -3109,12 +3102,16 @@ def definition_reference_schema(
'is-subclass',
'callable',
'list',
'tuple',
'tuple-positional',
'tuple-variable',
'set',
'frozenset',
'generator',
'dict',
'function',
'function-after',
'function-before',
'function-wrap',
'function-plain',
'default',
'nullable',
'union',
Expand Down
22 changes: 13 additions & 9 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,12 @@ combined_serializer! {
// hence they're here.
Function: super::type_serializers::function::FunctionPlainSerializer;
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
// `TuplePositionalSerializer` & `TupleVariableSerializer` are created by
// `TupleBuilder` based on the `mode` parameter.
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
}
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
find_only: {
super::type_serializers::tuple::TupleBuilder;
super::type_serializers::union::TaggedUnionBuilder;
super::type_serializers::other::ChainBuilder;
super::type_serializers::other::FunctionBuilder;
super::type_serializers::other::CustomErrorBuilder;
super::type_serializers::other::CallBuilder;
super::type_serializers::other::LaxOrStrictBuilder;
Expand All @@ -100,6 +94,10 @@ combined_serializer! {
super::type_serializers::definitions::DefinitionsBuilder;
super::type_serializers::dataclass::DataclassArgsBuilder;
super::type_serializers::dataclass::DataclassBuilder;
super::type_serializers::function::FunctionBeforeSerializerBuilder;
super::type_serializers::function::FunctionAfterSerializerBuilder;
super::type_serializers::function::FunctionPlainSerializerBuilder;
super::type_serializers::function::FunctionWrapSerializerBuilder;
}
// `both` means the struct is added to both the `CombinedSerializer` enum and the match statement in
// `find_serializer` so they can be used via a `type` str.
Expand Down Expand Up @@ -132,6 +130,8 @@ combined_serializer! {
Union: super::type_serializers::union::UnionSerializer;
Literal: super::type_serializers::literal::LiteralSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
}
}

Expand All @@ -150,11 +150,15 @@ impl CombinedSerializer {
Some("function-plain") => {
// `function` is a special case, not included in `find_serializer` since it means something
// different in `schema.type`
return super::type_serializers::function::FunctionPlainSerializer::new_combined(ser_schema)
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
return super::type_serializers::function::FunctionPlainSerializer::build(
ser_schema,
config,
build_context,
)
.map_err(|err| py_error_type!("Error building `function-plain` serializer:\n {}", err));
}
Some("function-wrap") => {
return super::type_serializers::function::FunctionWrapSerializer::new_combined(
return super::type_serializers::function::FunctionWrapSerializer::build(
ser_schema,
config,
build_context,
Expand Down
Loading

0 comments on commit 5efeaf9

Please sign in to comment.