Skip to content

Commit

Permalink
is_instance JSON support (#278)
Browse files Browse the repository at this point in the history
* is_instance json support

* add benchmark, remove Option from json_type

* remove defult impl of is_instance

* confirm tuples work too

* test for json string not is_instance
  • Loading branch information
samuelcolvin authored Oct 5, 2022
1 parent c13283d commit 66f2890
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 19 deletions.
4 changes: 3 additions & 1 deletion generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Callable
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Type, Union
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union

from typing_extensions import get_args, get_origin, is_typeddict

Expand Down Expand Up @@ -67,6 +67,8 @@ def get_schema(obj) -> core_schema.CoreSchema:
return {'type': 'literal', 'expected': expected}
elif issubclass(origin, List):
return {'type': 'list', 'items_schema': get_schema(obj.__args__[0])}
elif issubclass(origin, Set):
return {'type': 'set', 'items_schema': get_schema(obj.__args__[0])}
elif issubclass(origin, Dict):
return {
'type': 'dict',
Expand Down
20 changes: 14 additions & 6 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, List, Optional, Type, Union, overload
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union, overload

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Protocol, Required
Expand Down Expand Up @@ -278,13 +278,21 @@ def literal_schema(*expected: Any, ref: str | None = None) -> LiteralSchema:
return dict_not_none(type='literal', expected=expected, ref=ref)


class IsInstanceSchema(TypedDict):
type: Literal['is-instance']
cls: Type[Any]
# must match input/parse_json.rs::JsonType::try_from
JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict']


class IsInstanceSchema(TypedDict, total=False):
type: Required[Literal['is-instance']]
cls: Required[Type[Any]]
json_types: Set[JsonType]
ref: str


def is_instance_schema(cls: Type[Any]) -> IsInstanceSchema:
return dict_not_none(type='is-instance', cls=cls)
def is_instance_schema(
cls: Type[Any], *, json_types: Set[JsonType] | None = None, ref: str | None = None
) -> IsInstanceSchema:
return dict_not_none(type='is-instance', cls=cls, json_types=json_types, ref=ref)


class CallableSchema(TypedDict):
Expand Down
6 changes: 4 additions & 2 deletions src/errors/location.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ impl ToPyObject for LocItem {
}
}

impl LocItem {
pub fn try_from(value: &PyAny) -> PyResult<Self> {
impl TryFrom<&PyAny> for LocItem {
type Error = PyErr;

fn try_from(value: &PyAny) -> PyResult<Self> {
if let Ok(str) = value.extract::<String>() {
Ok(str.into())
} else {
Expand Down
4 changes: 1 addition & 3 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
None
}

fn is_instance(&self, _class: &PyType) -> PyResult<bool> {
Ok(false)
}
fn is_instance(&self, class: &PyType, json_mask: u8) -> PyResult<bool>;

fn callable(&self) -> bool {
false
Expand Down
29 changes: 29 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use pyo3::prelude::*;
use pyo3::types::PyType;

use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
use crate::input::JsonType;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -30,6 +34,23 @@ impl<'a> Input<'a> for JsonInput {
matches!(self, JsonInput::Null)
}

fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
if json_mask == 0 {
Ok(false)
} else {
let json_type: JsonType = match self {
JsonInput::Null => JsonType::Null,
JsonInput::Bool(_) => JsonType::Bool,
JsonInput::Int(_) => JsonType::Int,
JsonInput::Float(_) => JsonType::Float,
JsonInput::String(_) => JsonType::String,
JsonInput::Array(_) => JsonType::Array,
JsonInput::Object(_) => JsonType::Object,
};
Ok(json_type.matches(json_mask))
}
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
match self {
JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()),
Expand Down Expand Up @@ -284,6 +305,14 @@ impl<'a> Input<'a> for String {
false
}

fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
if json_mask == 0 {
Ok(false)
} else {
Ok(JsonType::String.matches(json_mask))
}
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
Err(ValError::new(ErrorKind::ArgumentsType, self))
Expand Down
2 changes: 1 addition & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<'a> Input<'a> for PyAny {
self.getattr(name).ok()
}

fn is_instance(&self, class: &PyType) -> PyResult<bool> {
fn is_instance(&self, class: &PyType, _json_mask: u8) -> PyResult<bool> {
self.is_instance(class)
}

Expand Down
2 changes: 1 addition & 1 deletion src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod shared;

pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
pub use input_abstract::Input;
pub use parse_json::{JsonInput, JsonObject};
pub use parse_json::{JsonInput, JsonObject, JsonType};
pub use return_enums::{
py_string_str, EitherBytes, EitherString, GenericArguments, GenericCollection, GenericIterator, GenericMapping,
JsonArgs, PyArgs,
Expand Down
43 changes: 42 additions & 1 deletion src/input/parse_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,50 @@ use std::fmt;

use indexmap::IndexMap;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PySet};
use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor};

use crate::build_tools::py_error;

#[derive(Clone, Debug)]
pub enum JsonType {
Null = 0b10000000,
Bool = 0b01000000,
Int = 0b00100000,
Float = 0b00010000,
String = 0b00001000,
Array = 0b00000100,
Object = 0b00000010,
}

impl JsonType {
pub fn combine(set: &PySet) -> PyResult<u8> {
set.iter().map(Self::try_from).try_fold(0u8, |a, b| Ok(a | b? as u8))
}

pub fn matches(&self, mask: u8) -> bool {
*self as u8 & mask > 0
}
}

impl TryFrom<&PyAny> for JsonType {
type Error = PyErr;

fn try_from(value: &PyAny) -> PyResult<Self> {
let s: &str = value.extract()?;
match s {
"null" => Ok(Self::Null),
"bool" => Ok(Self::Bool),
"int" => Ok(Self::Int),
"float" => Ok(Self::Float),
"str" => Ok(Self::String),
"list" => Ok(Self::Array),
"dict" => Ok(Self::Object),
_ => py_error!("Invalid json type: {}", s),
}
}
}

/// similar to serde `Value` but with int and float split
#[derive(Clone, Debug)]
pub enum JsonInput {
Expand Down
12 changes: 9 additions & 3 deletions src/validators/is_instance.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyType};
use pyo3::types::{PyDict, PySet, PyType};

use crate::build_tools::SchemaDict;
use crate::errors::{ErrorKind, ValError, ValResult};
use crate::input::Input;
use crate::input::{Input, JsonType};
use crate::recursion_guard::RecursionGuard;

use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

#[derive(Debug, Clone)]
pub struct IsInstanceValidator {
class: Py<PyType>,
json_types: u8,
class_repr: String,
name: String,
}
Expand All @@ -27,8 +28,13 @@ impl BuildValidator for IsInstanceValidator {
let class: &PyType = schema.get_as_req(intern!(schema.py(), "cls"))?;
let class_repr = class.name()?.to_string();
let name = format!("{}[{}]", Self::EXPECTED_TYPE, class_repr);
let json_types = match schema.get_as::<&PySet>(intern!(schema.py(), "json_types"))? {
Some(s) => JsonType::combine(s)?,
None => 0,
};
Ok(Self {
class: class.into(),
json_types,
class_repr,
name,
}
Expand All @@ -45,7 +51,7 @@ impl Validator for IsInstanceValidator {
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
match input.is_instance(self.class.as_ref(py))? {
match input.is_instance(self.class.as_ref(py), self.json_types)? {
true => Ok(input.to_object(py)),
false => Err(ValError::new(
ErrorKind::IsInstanceOf {
Expand Down
12 changes: 12 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,15 @@ def test_generator_rust(benchmark):
assert sum(v.validate_python(input_value)) == 4950

benchmark(v.validate_python, input_value)


@pytest.mark.benchmark(group='isinstance-json')
def test_isinstance_json(benchmark):
validator = SchemaValidator(core_schema.is_instance_schema(str, json_types={'str'}))
assert validator.isinstance_json('"foo"') is True
assert validator.isinstance_json('123') is False

@benchmark
def t():
validator.isinstance_json('"foo"')
validator.isinstance_json('123')
67 changes: 66 additions & 1 deletion tests/validators/test_is_instance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest

from pydantic_core import SchemaError, SchemaValidator, ValidationError
from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema

from ..conftest import plain_repr


class Foo:
Expand Down Expand Up @@ -118,3 +120,66 @@ def test_repr():
def test_is_type(input_val, value):
v = SchemaValidator({'type': 'is-instance', 'cls': type})
assert v.isinstance_python(input_val) == value


@pytest.mark.parametrize(
'input_val,expected',
[
('null', False),
('true', True),
('1', False),
('1.1', False),
('"a string"', True),
('["s"]', False),
('{"s": 1}', False),
],
)
def test_is_instance_json_string_bool(input_val, expected):
v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types={'str', 'bool'}))
assert v.isinstance_json(input_val) == expected


@pytest.mark.parametrize(
'input_val,expected',
[
('null', False),
('true', False),
('1', False),
('1.1', False),
('"a string"', False),
('["s"]', True),
('{"s": 1}', False),
],
)
def test_is_instance_json_list(input_val, expected):
v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types=('list',)))
assert v.isinstance_json(input_val) == expected


def test_is_instance_dict():
v = SchemaValidator(
core_schema.dict_schema(
keys_schema=core_schema.is_instance_schema(str, json_types={'str'}),
values_schema=core_schema.is_instance_schema(int, json_types={'int', 'dict'}),
)
)
assert v.isinstance_python({'foo': 1}) is True
assert v.isinstance_python({1: 1}) is False
assert v.isinstance_json('{"foo": 1}') is True
assert v.isinstance_json('{"foo": "1"}') is False
assert v.isinstance_json('{"foo": {"a": 1}}') is True


def test_is_instance_dict_not_str():
v = SchemaValidator(core_schema.dict_schema(keys_schema=core_schema.is_instance_schema(int, json_types={'int'})))
assert v.isinstance_python({1: 1}) is True
assert v.isinstance_python({'foo': 1}) is False
assert v.isinstance_json('{"foo": 1}') is False


def test_json_mask():
assert 'json_types:128' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types={'null'})))
assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str)))
assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types=set())))
v = SchemaValidator(core_schema.is_instance_schema(str, json_types={'list', 'dict'}))
assert 'json_types:6' in plain_repr(v) # 2 + 4

0 comments on commit 66f2890

Please sign in to comment.