Skip to content

Commit

Permalink
Change args to list (#508)
Browse files Browse the repository at this point in the history
* change first argument from *args to list for union_schema

* change first  argument for arguments_schem from *args to list

* change first argument for chain_schema from *args to list

* change first argument of tuple_positional_schem from *args to list

* change first argument of literal_schema from *args to list

* update docstring of union_schema
  • Loading branch information
realDragonium authored Apr 1, 2023
1 parent e65b46f commit 227292b
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 134 deletions.
38 changes: 18 additions & 20 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,15 +1043,15 @@ class LiteralSchema(TypedDict, total=False):


def literal_schema(
*expected: Any, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None
expected: list[Any], ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None
) -> LiteralSchema:
"""
Returns a schema that matches a literal value, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.literal_schema('hello', 'world')
schema = core_schema.literal_schema(['hello', 'world'])
v = SchemaValidator(schema)
assert v.validate_python('hello') == 'hello'
```
Expand All @@ -1062,9 +1062,7 @@ def literal_schema(
metadata: See [TODO] for details
serialization: Custom serialization schema
"""
return dict_not_none(
type='literal', expected=list(expected), ref=ref, metadata=metadata, serialization=serialization
)
return dict_not_none(type='literal', expected=expected, ref=ref, metadata=metadata, serialization=serialization)


# must match input/parse_json.rs::JsonType::try_from
Expand Down Expand Up @@ -1284,7 +1282,7 @@ class TuplePositionalSchema(TypedDict, total=False):


def tuple_positional_schema(
*items_schema: CoreSchema,
items_schema: list[CoreSchema],
extra_schema: CoreSchema | None = None,
strict: bool | None = None,
ref: str | None = None,
Expand All @@ -1298,7 +1296,7 @@ def tuple_positional_schema(
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.tuple_positional_schema(
core_schema.int_schema(), core_schema.str_schema()
[core_schema.int_schema(), core_schema.str_schema()]
)
v = SchemaValidator(schema)
assert v.validate_python((1, 'hello')) == (1, 'hello')
Expand All @@ -1317,7 +1315,7 @@ def tuple_positional_schema(
"""
return dict_not_none(
type='tuple-positional',
items_schema=list(items_schema),
items_schema=items_schema,
extra_schema=extra_schema,
strict=strict,
ref=ref,
Expand Down Expand Up @@ -2194,7 +2192,7 @@ class UnionSchema(TypedDict, total=False):


def union_schema(
*choices: CoreSchema,
choices: list[CoreSchema],
auto_collapse: bool | None = None,
custom_error_type: str | None = None,
custom_error_message: str | None = None,
Expand All @@ -2210,14 +2208,14 @@ def union_schema(
```py
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.union_schema(core_schema.str_schema(), core_schema.int_schema())
schema = core_schema.union_schema([core_schema.str_schema(), core_schema.int_schema()])
v = SchemaValidator(schema)
assert v.validate_python('hello') == 'hello'
assert v.validate_python(1) == 1
```
Args:
*choices: The schemas to match
choices: The schemas to match
auto_collapse: whether to automatically collapse unions with one element to the inner validator, default true
custom_error_type: The custom error type to use if the validation fails
custom_error_message: The custom error message to use if the validation fails
Expand All @@ -2229,7 +2227,7 @@ def union_schema(
"""
return dict_not_none(
type='union',
choices=list(choices),
choices=choices,
auto_collapse=auto_collapse,
custom_error_type=custom_error_type,
custom_error_message=custom_error_message,
Expand Down Expand Up @@ -2349,7 +2347,7 @@ class ChainSchema(TypedDict, total=False):


def chain_schema(
*steps: CoreSchema, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None
steps: list[CoreSchema], ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None
) -> ChainSchema:
"""
Returns a schema that chains the provided validation schemas, e.g.:
Expand All @@ -2363,7 +2361,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
fn_schema = core_schema.general_plain_validator_function(function=fn)
schema = core_schema.chain_schema(
fn_schema, fn_schema, fn_schema, core_schema.str_schema()
[fn_schema, fn_schema, fn_schema, core_schema.str_schema()]
)
v = SchemaValidator(schema)
assert v.validate_python('hello') == 'hello world world world'
Expand All @@ -2375,7 +2373,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str:
metadata: See [TODO] for details
serialization: Custom serialization schema
"""
return dict_not_none(type='chain', steps=list(steps), ref=ref, metadata=metadata, serialization=serialization)
return dict_not_none(type='chain', steps=steps, ref=ref, metadata=metadata, serialization=serialization)


class LaxOrStrictSchema(TypedDict, total=False):
Expand Down Expand Up @@ -2846,7 +2844,7 @@ def arguments_parameter(
param = core_schema.arguments_parameter(
name='a', schema=core_schema.str_schema(), mode='positional_only'
)
schema = core_schema.arguments_schema(param)
schema = core_schema.arguments_schema([param])
v = SchemaValidator(schema)
assert v.validate_python(('hello',)) == (('hello',), {})
```
Expand All @@ -2872,7 +2870,7 @@ class ArgumentsSchema(TypedDict, total=False):


def arguments_schema(
*arguments: ArgumentsParameter,
arguments: list[ArgumentsParameter],
populate_by_name: bool | None = None,
var_args_schema: CoreSchema | None = None,
var_kwargs_schema: CoreSchema | None = None,
Expand All @@ -2892,7 +2890,7 @@ def arguments_schema(
param_b = core_schema.arguments_parameter(
name='b', schema=core_schema.bool_schema(), mode='positional_only'
)
schema = core_schema.arguments_schema(param_a, param_b)
schema = core_schema.arguments_schema([param_a, param_b])
v = SchemaValidator(schema)
assert v.validate_python(('hello', True)) == (('hello', True), {})
```
Expand All @@ -2908,7 +2906,7 @@ def arguments_schema(
"""
return dict_not_none(
type='arguments',
arguments_schema=list(arguments),
arguments_schema=arguments,
populate_by_name=populate_by_name,
var_args_schema=var_args_schema,
var_kwargs_schema=var_kwargs_schema,
Expand Down Expand Up @@ -2949,7 +2947,7 @@ def call_schema(
param_b = core_schema.arguments_parameter(
name='b', schema=core_schema.bool_schema(), mode='positional_only'
)
args_schema = core_schema.arguments_schema(param_a, param_b)
args_schema = core_schema.arguments_schema([param_a, param_b])
schema = core_schema.call_schema(
arguments=args_schema,
Expand Down
38 changes: 22 additions & 16 deletions tests/serializers/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def test_repeated_ref():
with pytest.raises(SchemaError, match='SchemaError: Duplicate ref: `foobar`'):
SchemaSerializer(
core_schema.tuple_positional_schema(
core_schema.int_schema(ref='foobar'),
core_schema.definition_reference_schema('foobar'),
core_schema.int_schema(ref='foobar'),
[
core_schema.int_schema(ref='foobar'),
core_schema.definition_reference_schema('foobar'),
core_schema.int_schema(ref='foobar'),
]
)
)

Expand All @@ -56,11 +58,13 @@ def test_repeat_after():
with pytest.raises(SchemaError, match='SchemaError: Duplicate ref: `foobar`'):
SchemaSerializer(
core_schema.tuple_positional_schema(
core_schema.definitions_schema(
core_schema.list_schema(core_schema.definition_reference_schema('foobar')),
[core_schema.int_schema(ref='foobar')],
),
core_schema.int_schema(ref='foobar'),
[
core_schema.definitions_schema(
core_schema.list_schema(core_schema.definition_reference_schema('foobar')),
[core_schema.int_schema(ref='foobar')],
),
core_schema.int_schema(ref='foobar'),
]
)
)

Expand Down Expand Up @@ -94,15 +98,17 @@ def test_deep():
def test_use_after():
v = SchemaSerializer(
core_schema.tuple_positional_schema(
core_schema.definitions_schema(
[
core_schema.definitions_schema(
core_schema.definition_reference_schema('foobar'),
[
core_schema.int_schema(
ref='foobar', serialization=core_schema.to_string_ser_schema(when_used='always')
)
],
),
core_schema.definition_reference_schema('foobar'),
[
core_schema.int_schema(
ref='foobar', serialization=core_schema.to_string_ser_schema(when_used='always')
)
],
),
core_schema.definition_reference_schema('foobar'),
]
)
)
assert v.to_python((1, 2)) == ('1', '2')
2 changes: 1 addition & 1 deletion tests/serializers/test_list_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def test_tuple_pos_dict_key():
s = SchemaSerializer(
core_schema.dict_schema(
core_schema.tuple_positional_schema(
core_schema.int_schema(), core_schema.str_schema(), extra_schema=core_schema.int_schema()
[core_schema.int_schema(), core_schema.str_schema()], extra_schema=core_schema.int_schema()
),
core_schema.int_schema(),
)
Expand Down
8 changes: 4 additions & 4 deletions tests/serializers/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def test_int_literal():
s = SchemaSerializer(core_schema.literal_schema(1, 2, 3))
s = SchemaSerializer(core_schema.literal_schema([1, 2, 3]))
r = plain_repr(s)
assert 'expected_int:{' in r
assert 'expected_str:{}' in r
Expand All @@ -25,7 +25,7 @@ def test_int_literal():


def test_str_literal():
s = SchemaSerializer(core_schema.literal_schema('a', 'b', 'c'))
s = SchemaSerializer(core_schema.literal_schema(['a', 'b', 'c']))
r = plain_repr(s)
assert 'expected_str:{' in r
assert 'expected_int:{}' in r
Expand All @@ -44,7 +44,7 @@ def test_str_literal():


def test_other_literal():
s = SchemaSerializer(core_schema.literal_schema('a', 1))
s = SchemaSerializer(core_schema.literal_schema(['a', 1]))
assert 'expected_int:{1},expected_str:{"a"},expected_py:None' in plain_repr(s)

assert s.to_python('a') == 'a'
Expand All @@ -60,4 +60,4 @@ def test_other_literal():

def test_empty_literal():
with pytest.raises(SchemaError, match='`expected` should have length > 0'):
SchemaSerializer(core_schema.literal_schema())
SchemaSerializer(core_schema.literal_schema([]))
10 changes: 6 additions & 4 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ class DataClass:
def test_dataclass():
schema = core_schema.call_schema(
core_schema.arguments_schema(
core_schema.arguments_parameter('foo', core_schema.int_schema()),
core_schema.arguments_parameter('bar', core_schema.str_schema()),
core_schema.arguments_parameter('spam', core_schema.bytes_schema(), mode='keyword_only'),
core_schema.arguments_parameter('frog', core_schema.int_schema(), mode='keyword_only'),
[
core_schema.arguments_parameter('foo', core_schema.int_schema()),
core_schema.arguments_parameter('bar', core_schema.str_schema()),
core_schema.arguments_parameter('spam', core_schema.bytes_schema(), mode='keyword_only'),
core_schema.arguments_parameter('frog', core_schema.int_schema(), mode='keyword_only'),
]
),
DataClass,
serialization=core_schema.model_ser_schema(
Expand Down
2 changes: 1 addition & 1 deletion tests/serializers/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def test_chain():
s = SchemaSerializer(core_schema.chain_schema(core_schema.str_schema(), core_schema.int_schema()))
s = SchemaSerializer(core_schema.chain_schema([core_schema.str_schema(), core_schema.int_schema()]))

# insert_assert(plain_repr(s))
assert plain_repr(s) == 'SchemaSerializer(serializer=Int(IntSerializer),slots=[])'
Expand Down
Loading

0 comments on commit 227292b

Please sign in to comment.