From 66f289041b67976fe20be76112477bc05546339a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 5 Oct 2022 12:18:03 +0100 Subject: [PATCH] `is_instance` JSON support (#278) * is_instance json support * add benchmark, remove Option from json_type * remove defult impl of is_instance * confirm tuples work too * test for json string not is_instance --- generate_self_schema.py | 4 +- pydantic_core/core_schema.py | 20 +++++-- src/errors/location.rs | 6 +- src/input/input_abstract.rs | 4 +- src/input/input_json.rs | 29 ++++++++++ src/input/input_python.rs | 2 +- src/input/mod.rs | 2 +- src/input/parse_json.rs | 43 ++++++++++++++- src/validators/is_instance.rs | 12 +++- tests/benchmarks/test_micro_benchmarks.py | 12 ++++ tests/validators/test_is_instance.py | 67 ++++++++++++++++++++++- 11 files changed, 182 insertions(+), 19 deletions(-) diff --git a/generate_self_schema.py b/generate_self_schema.py index a33c82ae5..02537bf23 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -11,7 +11,7 @@ from collections.abc import Callable from datetime import date, datetime, time, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Type, Union +from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union from typing_extensions import get_args, get_origin, is_typeddict @@ -67,6 +67,8 @@ def get_schema(obj) -> core_schema.CoreSchema: return {'type': 'literal', 'expected': expected} elif issubclass(origin, List): return {'type': 'list', 'items_schema': get_schema(obj.__args__[0])} + elif issubclass(origin, Set): + return {'type': 'set', 'items_schema': get_schema(obj.__args__[0])} elif issubclass(origin, Dict): return { 'type': 'dict', diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index e17ac27ff..b32ea604d 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -2,7 +2,7 @@ import sys from datetime import date, datetime, time, timedelta -from typing import Any, Callable, Dict, List, Optional, Type, Union, overload +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union, overload if sys.version_info < (3, 11): from typing_extensions import NotRequired, Protocol, Required @@ -278,13 +278,21 @@ def literal_schema(*expected: Any, ref: str | None = None) -> LiteralSchema: return dict_not_none(type='literal', expected=expected, ref=ref) -class IsInstanceSchema(TypedDict): - type: Literal['is-instance'] - cls: Type[Any] +# must match input/parse_json.rs::JsonType::try_from +JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict'] + + +class IsInstanceSchema(TypedDict, total=False): + type: Required[Literal['is-instance']] + cls: Required[Type[Any]] + json_types: Set[JsonType] + ref: str -def is_instance_schema(cls: Type[Any]) -> IsInstanceSchema: - return dict_not_none(type='is-instance', cls=cls) +def is_instance_schema( + cls: Type[Any], *, json_types: Set[JsonType] | None = None, ref: str | None = None +) -> IsInstanceSchema: + return dict_not_none(type='is-instance', cls=cls, json_types=json_types, ref=ref) class CallableSchema(TypedDict): diff --git a/src/errors/location.rs b/src/errors/location.rs index 28c623eea..6402abff8 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -50,8 +50,10 @@ impl ToPyObject for LocItem { } } -impl LocItem { - pub fn try_from(value: &PyAny) -> PyResult { +impl TryFrom<&PyAny> for LocItem { + type Error = PyErr; + + fn try_from(value: &PyAny) -> PyResult { if let Ok(str) = value.extract::() { Ok(str.into()) } else { diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 2ffd623f1..31f1260a7 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -34,9 +34,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { None } - fn is_instance(&self, _class: &PyType) -> PyResult { - Ok(false) - } + fn is_instance(&self, class: &PyType, json_mask: u8) -> PyResult; fn callable(&self) -> bool { false diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 448b21e8b..da2609f8f 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,4 +1,8 @@ +use pyo3::prelude::*; +use pyo3::types::PyType; + use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult}; +use crate::input::JsonType; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, @@ -30,6 +34,23 @@ impl<'a> Input<'a> for JsonInput { matches!(self, JsonInput::Null) } + fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult { + if json_mask == 0 { + Ok(false) + } else { + let json_type: JsonType = match self { + JsonInput::Null => JsonType::Null, + JsonInput::Bool(_) => JsonType::Bool, + JsonInput::Int(_) => JsonType::Int, + JsonInput::Float(_) => JsonType::Float, + JsonInput::String(_) => JsonType::String, + JsonInput::Array(_) => JsonType::Array, + JsonInput::Object(_) => JsonType::Object, + }; + Ok(json_type.matches(json_mask)) + } + } + fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { match self { JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()), @@ -284,6 +305,14 @@ impl<'a> Input<'a> for String { false } + fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult { + if json_mask == 0 { + Ok(false) + } else { + Ok(JsonType::String.matches(json_mask)) + } + } + #[cfg_attr(has_no_coverage, no_coverage)] fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { Err(ValError::new(ErrorKind::ArgumentsType, self)) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 41806b63b..b9b8f0305 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -81,7 +81,7 @@ impl<'a> Input<'a> for PyAny { self.getattr(name).ok() } - fn is_instance(&self, class: &PyType) -> PyResult { + fn is_instance(&self, class: &PyType, _json_mask: u8) -> PyResult { self.is_instance(class) } diff --git a/src/input/mod.rs b/src/input/mod.rs index 9ebf0162d..6e70f3779 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -10,7 +10,7 @@ mod shared; pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; pub use input_abstract::Input; -pub use parse_json::{JsonInput, JsonObject}; +pub use parse_json::{JsonInput, JsonObject, JsonType}; pub use return_enums::{ py_string_str, EitherBytes, EitherString, GenericArguments, GenericCollection, GenericIterator, GenericMapping, JsonArgs, PyArgs, diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs index 4f0f5c430..a35d1f054 100644 --- a/src/input/parse_json.rs +++ b/src/input/parse_json.rs @@ -2,9 +2,50 @@ use std::fmt; use indexmap::IndexMap; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::types::{PyDict, PySet}; use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; +use crate::build_tools::py_error; + +#[derive(Clone, Debug)] +pub enum JsonType { + Null = 0b10000000, + Bool = 0b01000000, + Int = 0b00100000, + Float = 0b00010000, + String = 0b00001000, + Array = 0b00000100, + Object = 0b00000010, +} + +impl JsonType { + pub fn combine(set: &PySet) -> PyResult { + set.iter().map(Self::try_from).try_fold(0u8, |a, b| Ok(a | b? as u8)) + } + + pub fn matches(&self, mask: u8) -> bool { + *self as u8 & mask > 0 + } +} + +impl TryFrom<&PyAny> for JsonType { + type Error = PyErr; + + fn try_from(value: &PyAny) -> PyResult { + let s: &str = value.extract()?; + match s { + "null" => Ok(Self::Null), + "bool" => Ok(Self::Bool), + "int" => Ok(Self::Int), + "float" => Ok(Self::Float), + "str" => Ok(Self::String), + "list" => Ok(Self::Array), + "dict" => Ok(Self::Object), + _ => py_error!("Invalid json type: {}", s), + } + } +} + /// similar to serde `Value` but with int and float split #[derive(Clone, Debug)] pub enum JsonInput { diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index f86774ee0..3b853aeb8 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -1,10 +1,10 @@ use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyType}; +use pyo3::types::{PyDict, PySet, PyType}; use crate::build_tools::SchemaDict; use crate::errors::{ErrorKind, ValError, ValResult}; -use crate::input::Input; +use crate::input::{Input, JsonType}; use crate::recursion_guard::RecursionGuard; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; @@ -12,6 +12,7 @@ use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct IsInstanceValidator { class: Py, + json_types: u8, class_repr: String, name: String, } @@ -27,8 +28,13 @@ impl BuildValidator for IsInstanceValidator { let class: &PyType = schema.get_as_req(intern!(schema.py(), "cls"))?; let class_repr = class.name()?.to_string(); let name = format!("{}[{}]", Self::EXPECTED_TYPE, class_repr); + let json_types = match schema.get_as::<&PySet>(intern!(schema.py(), "json_types"))? { + Some(s) => JsonType::combine(s)?, + None => 0, + }; Ok(Self { class: class.into(), + json_types, class_repr, name, } @@ -45,7 +51,7 @@ impl Validator for IsInstanceValidator { _slots: &'data [CombinedValidator], _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - match input.is_instance(self.class.as_ref(py))? { + match input.is_instance(self.class.as_ref(py), self.json_types)? { true => Ok(input.to_object(py)), false => Err(ValError::new( ErrorKind::IsInstanceOf { diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index c837d7559..39456e376 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1003,3 +1003,15 @@ def test_generator_rust(benchmark): assert sum(v.validate_python(input_value)) == 4950 benchmark(v.validate_python, input_value) + + +@pytest.mark.benchmark(group='isinstance-json') +def test_isinstance_json(benchmark): + validator = SchemaValidator(core_schema.is_instance_schema(str, json_types={'str'})) + assert validator.isinstance_json('"foo"') is True + assert validator.isinstance_json('123') is False + + @benchmark + def t(): + validator.isinstance_json('"foo"') + validator.isinstance_json('123') diff --git a/tests/validators/test_is_instance.py b/tests/validators/test_is_instance.py index c0add6810..6da88b4a0 100644 --- a/tests/validators/test_is_instance.py +++ b/tests/validators/test_is_instance.py @@ -1,6 +1,8 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, ValidationError +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema + +from ..conftest import plain_repr class Foo: @@ -118,3 +120,66 @@ def test_repr(): def test_is_type(input_val, value): v = SchemaValidator({'type': 'is-instance', 'cls': type}) assert v.isinstance_python(input_val) == value + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + ('null', False), + ('true', True), + ('1', False), + ('1.1', False), + ('"a string"', True), + ('["s"]', False), + ('{"s": 1}', False), + ], +) +def test_is_instance_json_string_bool(input_val, expected): + v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types={'str', 'bool'})) + assert v.isinstance_json(input_val) == expected + + +@pytest.mark.parametrize( + 'input_val,expected', + [ + ('null', False), + ('true', False), + ('1', False), + ('1.1', False), + ('"a string"', False), + ('["s"]', True), + ('{"s": 1}', False), + ], +) +def test_is_instance_json_list(input_val, expected): + v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types=('list',))) + assert v.isinstance_json(input_val) == expected + + +def test_is_instance_dict(): + v = SchemaValidator( + core_schema.dict_schema( + keys_schema=core_schema.is_instance_schema(str, json_types={'str'}), + values_schema=core_schema.is_instance_schema(int, json_types={'int', 'dict'}), + ) + ) + assert v.isinstance_python({'foo': 1}) is True + assert v.isinstance_python({1: 1}) is False + assert v.isinstance_json('{"foo": 1}') is True + assert v.isinstance_json('{"foo": "1"}') is False + assert v.isinstance_json('{"foo": {"a": 1}}') is True + + +def test_is_instance_dict_not_str(): + v = SchemaValidator(core_schema.dict_schema(keys_schema=core_schema.is_instance_schema(int, json_types={'int'}))) + assert v.isinstance_python({1: 1}) is True + assert v.isinstance_python({'foo': 1}) is False + assert v.isinstance_json('{"foo": 1}') is False + + +def test_json_mask(): + assert 'json_types:128' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types={'null'}))) + assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str))) + assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types=set()))) + v = SchemaValidator(core_schema.is_instance_schema(str, json_types={'list', 'dict'})) + assert 'json_types:6' in plain_repr(v) # 2 + 4