Skip to content

Commit

Permalink
switch all schema typeddicts to total=False
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 16, 2022
1 parent f39be3c commit 6eea0ed
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 96 deletions.
6 changes: 0 additions & 6 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def type_dict_schema(typed_dict) -> dict[str, Any]:
schema = None
if type(field_type) == ForwardRef:
fr_arg = field_type.__forward_arg__
fr_arg, matched = re.subn(r'NotRequired\[(.+)]', r'\1', fr_arg)
if matched:
required = False

fr_arg, matched = re.subn(r'Required\[(.+)]', r'\1', fr_arg)
if matched:
Expand All @@ -116,9 +113,6 @@ def type_dict_schema(typed_dict) -> dict[str, Any]:
if get_origin(field_type) == core_schema.Required:
required = True
field_type = field_type.__args__[0]
if get_origin(field_type) == core_schema.NotRequired:
required = False
field_type = field_type.__args__[0]

schema = get_schema(field_type)

Expand Down
182 changes: 92 additions & 90 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Protocol, Required
from typing_extensions import Protocol, Required
else:
from typing import NotRequired, Protocol, Required
from typing import Protocol, Required

if sys.version_info < (3, 9):
from typing_extensions import Literal, TypedDict
Expand Down Expand Up @@ -44,22 +44,20 @@ class CoreConfig(TypedDict, total=False):
allow_inf_nan: bool # default: True


class AnySchema(TypedDict):
type: Literal['any']
extra: NotRequired[Any]
class AnySchema(TypedDict, total=False):
type: Required[Literal['any']]
ref: str
extra: Any


def any_schema(*, extra: Any = None) -> AnySchema:
if extra is None:
return {'type': 'any'}
else:
return {'type': 'any', 'extra': extra}
def any_schema(*, ref: str | None = None, extra: Any = None) -> AnySchema:
return dict_not_none(type='any', ref=ref, extra=extra)


class NoneSchema(TypedDict):
type: Literal['none']
ref: NotRequired[str]
extra: NotRequired[Any]
class NoneSchema(TypedDict, total=False):
type: Required[Literal['none']]
ref: str
extra: Any


def none_schema(*, ref: str | None = None, extra: Any = None) -> NoneSchema:
Expand Down Expand Up @@ -301,11 +299,11 @@ def timedelta_schema(
return dict_not_none(type='timedelta', strict=strict, le=le, ge=ge, lt=lt, gt=gt, ref=ref, extra=extra)


class LiteralSchema(TypedDict):
type: Literal['literal']
expected: List[Any]
ref: NotRequired[str]
extra: NotRequired[Any]
class LiteralSchema(TypedDict, total=False):
type: Required[Literal['literal']]
expected: Required[List[Any]]
ref: str
extra: Any


def literal_schema(*expected: Any, ref: str | None = None, extra: Any = None) -> LiteralSchema:
Expand All @@ -330,12 +328,14 @@ def is_instance_schema(
return dict_not_none(type='is-instance', cls=cls, json_types=json_types, ref=ref, extra=extra)


class CallableSchema(TypedDict):
type: Literal['callable']
class CallableSchema(TypedDict, total=False):
type: Required[Literal['callable']]
ref: str
extra: Any


def callable_schema() -> CallableSchema:
return dict_not_none(type='callable')
def callable_schema(*, ref: str | None = None, extra: Any = None) -> CallableSchema:
return dict_not_none(type='callable', ref=ref, extra=extra)


class ListSchema(TypedDict, total=False):
Expand Down Expand Up @@ -551,13 +551,13 @@ def __call__(
...


class FunctionSchema(TypedDict):
type: Literal['function']
mode: Literal['before', 'after']
function: ValidatorFunction
schema: CoreSchema
ref: NotRequired[str]
extra: NotRequired[Any]
class FunctionSchema(TypedDict, total=False):
type: Required[Literal['function']]
mode: Required[Literal['before', 'after']]
function: Required[ValidatorFunction]
schema: Required[CoreSchema]
ref: str
extra: Any


def function_before_schema(
Expand Down Expand Up @@ -591,13 +591,13 @@ def __call__(
...


class FunctionWrapSchema(TypedDict):
type: Literal['function']
mode: Literal['wrap']
function: WrapValidatorFunction
schema: CoreSchema
ref: NotRequired[str]
extra: NotRequired[Any]
class FunctionWrapSchema(TypedDict, total=False):
type: Required[Literal['function']]
mode: Required[Literal['wrap']]
function: Required[WrapValidatorFunction]
schema: Required[CoreSchema]
ref: str
extra: Any


def function_wrap_schema(
Expand All @@ -606,12 +606,12 @@ def function_wrap_schema(
return dict_not_none(type='function', mode='wrap', function=function, schema=schema, ref=ref, extra=extra)


class FunctionPlainSchema(TypedDict):
type: Literal['function']
mode: Literal['plain']
function: ValidatorFunction
ref: NotRequired[str]
extra: NotRequired[Any]
class FunctionPlainSchema(TypedDict, total=False):
type: Required[Literal['function']]
mode: Required[Literal['plain']]
function: Required[ValidatorFunction]
ref: str
extra: Any


def function_plain_schema(
Expand Down Expand Up @@ -710,16 +710,18 @@ def union_schema(
)


class TaggedUnionSchema(TypedDict):
type: Literal['tagged-union']
choices: Dict[str, CoreSchema]
discriminator: Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[str]]]
custom_error_kind: NotRequired[str]
custom_error_message: NotRequired[str]
custom_error_context: NotRequired[Dict[str, Union[str, int, float]]]
strict: NotRequired[bool]
ref: NotRequired[str]
extra: NotRequired[Any]
class TaggedUnionSchema(TypedDict, total=False):
type: Required[Literal['tagged-union']]
choices: Required[Dict[str, CoreSchema]]
discriminator: Required[
Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[str]]]
]
custom_error_kind: str
custom_error_message: str
custom_error_context: Dict[str, Union[str, int, float]]
strict: bool
ref: str
extra: Any


def tagged_union_schema(
Expand All @@ -746,11 +748,11 @@ def tagged_union_schema(
)


class ChainSchema(TypedDict):
type: Literal['chain']
steps: List[CoreSchema]
ref: NotRequired[str]
extra: NotRequired[Any]
class ChainSchema(TypedDict, total=False):
type: Required[Literal['chain']]
steps: Required[List[CoreSchema]]
ref: str
extra: Any


def chain_schema(*steps: CoreSchema, ref: str | None = None, extra: Any = None) -> ChainSchema:
Expand Down Expand Up @@ -817,15 +819,15 @@ def typed_dict_schema(
)


class NewClassSchema(TypedDict):
type: Literal['new-class']
cls: Type[Any]
schema: CoreSchema
call_after_init: NotRequired[str]
strict: NotRequired[bool]
ref: NotRequired[str]
extra: NotRequired[Any]
config: NotRequired[CoreConfig]
class NewClassSchema(TypedDict, total=False):
type: Required[Literal['new-class']]
cls: Required[Type[Any]]
schema: Required[CoreSchema]
call_after_init: str
strict: bool
ref: str
extra: Any
config: CoreConfig


def new_class_schema(
Expand Down Expand Up @@ -889,13 +891,13 @@ def arguments_schema(
)


class CallSchema(TypedDict):
type: Literal['call']
arguments_schema: CoreSchema
function: Callable[..., Any]
return_schema: NotRequired[CoreSchema]
ref: NotRequired[str]
extra: NotRequired[Any]
class CallSchema(TypedDict, total=False):
type: Required[Literal['call']]
arguments_schema: Required[CoreSchema]
function: Required[Callable[..., Any]]
return_schema: CoreSchema
ref: str
extra: Any


def call_schema(
Expand All @@ -911,23 +913,23 @@ def call_schema(
)


class RecursiveReferenceSchema(TypedDict):
type: Literal['recursive-ref']
schema_ref: str
class RecursiveReferenceSchema(TypedDict, total=False):
type: Required[Literal['recursive-ref']]
schema_ref: Required[str]


def recursive_reference_schema(schema_ref: str) -> RecursiveReferenceSchema:
return {'type': 'recursive-ref', 'schema_ref': schema_ref}


class CustomErrorSchema(TypedDict):
type: Literal['custom_error']
schema: CoreSchema
custom_error_kind: str
custom_error_message: NotRequired[str]
custom_error_context: NotRequired[Dict[str, Union[str, int, float]]]
ref: NotRequired[str]
extra: NotRequired[Any]
class CustomErrorSchema(TypedDict, total=False):
type: Required[Literal['custom_error']]
schema: Required[CoreSchema]
custom_error_kind: Required[str]
custom_error_message: str
custom_error_context: Dict[str, Union[str, int, float]]
ref: str
extra: Any


def custom_error_schema(
Expand All @@ -950,11 +952,11 @@ def custom_error_schema(
)


class JsonSchema(TypedDict):
type: Literal['json']
schema: NotRequired[CoreSchema]
ref: NotRequired[str]
extra: NotRequired[Any]
class JsonSchema(TypedDict, total=False):
type: Required[Literal['json']]
schema: CoreSchema
ref: str
extra: Any


def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, extra: Any = None) -> JsonSchema:
Expand Down

0 comments on commit 6eea0ed

Please sign in to comment.