Skip to content

Commit

Permalink
Generator type (#276)
Browse files Browse the repository at this point in the history
* testing generator type

* fix imports in benchmarks

* implmenting InternalValidator properly

* cleanup and generator_schema

* TooLong check and reuse InternalValidator in functions

* customising error titles
  • Loading branch information
samuelcolvin authored Oct 4, 2022
1 parent dec03f4 commit c13283d
Show file tree
Hide file tree
Showing 17 changed files with 572 additions and 43 deletions.
22 changes: 18 additions & 4 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def callable_schema() -> CallableSchema:

class ListSchema(TypedDict, total=False):
type: Required[Literal['list']]
items_schema: CoreSchema # default: AnySchema
items_schema: CoreSchema
min_length: int
max_length: int
strict: bool
Expand Down Expand Up @@ -368,7 +368,7 @@ def tuple_variable_schema(

class SetSchema(TypedDict, total=False):
type: Required[Literal['set']]
items_schema: CoreSchema # default: AnySchema
items_schema: CoreSchema
min_length: int
max_length: int
strict: bool
Expand All @@ -390,14 +390,14 @@ def set_schema(

class FrozenSetSchema(TypedDict, total=False):
type: Required[Literal['frozenset']]
items_schema: CoreSchema # default: AnySchema
items_schema: CoreSchema
min_length: int
max_length: int
strict: bool
ref: str


def frozen_set_schema(
def frozenset_schema(
items_schema: CoreSchema | None = None,
*,
min_length: int | None = None,
Expand All @@ -415,6 +415,19 @@ def frozen_set_schema(
)


class GeneratorSchema(TypedDict, total=False):
type: Required[Literal['generator']]
items_schema: CoreSchema
max_length: int
ref: str


def generator_schema(
items_schema: CoreSchema | None = None, *, max_length: int | None = None, ref: str | None = None
) -> GeneratorSchema:
return dict_not_none(type='generator', items_schema=items_schema, max_length=max_length, ref=ref)


class DictSchema(TypedDict, total=False):
type: Required[Literal['dict']]
keys_schema: CoreSchema # default: AnySchema
Expand Down Expand Up @@ -865,6 +878,7 @@ def recursive_reference_schema(schema_ref: str) -> RecursiveReferenceSchema:
TupleVariableSchema,
SetSchema,
FrozenSetSchema,
GeneratorSchema,
DictSchema,
FunctionSchema,
FunctionWrapSchema,
Expand Down
4 changes: 3 additions & 1 deletion src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ pub enum ErrorKind {
error: String,
},
// ---------------------
// generic list-list errors
// generic collection and iteration errors
#[strum(message = "Input should be iterable")]
IterableType,
#[strum(message = "Error iterating over object")]
IterationError,
// ---------------------
Expand Down
4 changes: 3 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::input::datetime::EitherTime;

use super::datetime::{EitherDate, EitherDateTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherString};
use super::{GenericArguments, GenericCollection, GenericMapping};
use super::{GenericArguments, GenericCollection, GenericIterator, GenericMapping};

/// all types have three methods: `validate_*`, `strict_*`, `lax_*`
/// the convention is to either implement:
Expand Down Expand Up @@ -178,6 +178,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_frozenset()
}

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<EitherDate> {
if strict {
self.strict_date()
Expand Down
26 changes: 24 additions & 2 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::parse_json::JsonArray;
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
use super::{
EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericCollection, GenericMapping, Input, JsonArgs,
JsonInput,
EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericCollection, GenericIterator, GenericMapping,
Input, JsonArgs, JsonInput,
};

impl<'a> Input<'a> for JsonInput {
Expand Down Expand Up @@ -196,6 +197,19 @@ impl<'a> Input<'a> for JsonInput {
self.validate_frozenset(false)
}

fn validate_iter(&self) -> ValResult<GenericIterator> {
match self {
JsonInput::Array(a) => Ok(a.clone().into()),
JsonInput::String(s) => Ok(string_to_vec(s).into()),
JsonInput::Object(object) => {
// return keys iterator to match python's behavior
let keys: Vec<JsonInput> = object.keys().map(|k| JsonInput::String(k.clone())).collect();
Ok(keys.into())
}
_ => Err(ValError::new(ErrorKind::IterableType, self)),
}
}

fn validate_date(&self, _strict: bool) -> ValResult<EitherDate> {
match self {
JsonInput::String(v) => bytes_as_date(self, v.as_bytes()),
Expand Down Expand Up @@ -363,6 +377,10 @@ impl<'a> Input<'a> for String {
self.validate_frozenset(false)
}

fn validate_iter(&self) -> ValResult<GenericIterator> {
Ok(string_to_vec(self).into())
}

fn validate_date(&self, _strict: bool) -> ValResult<EitherDate> {
bytes_as_date(self, self.as_bytes())
}
Expand Down Expand Up @@ -395,3 +413,7 @@ impl<'a> Input<'a> for String {
self.validate_timedelta(false)
}
}

fn string_to_vec(s: &str) -> JsonArray {
s.chars().map(|c| JsonInput::String(c.to_string())).collect()
}
9 changes: 9 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
use pyo3::{intern, AsPyPointer};

use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
use crate::input::GenericIterator;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
Expand Down Expand Up @@ -441,6 +442,14 @@ impl<'a> Input<'a> for PyAny {
}
}

fn validate_iter(&self) -> ValResult<GenericIterator> {
if self.iter().is_ok() {
Ok(self.into())
} else {
Err(ValError::new(ErrorKind::IterableType, self))
}
}

fn strict_date(&self) -> ValResult<EitherDate> {
if self.cast_as::<PyDateTime>().is_ok() {
// have to check if it's a datetime first, otherwise the line below converts to a date
Expand Down
3 changes: 2 additions & 1 deletion src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
pub use input_abstract::Input;
pub use parse_json::{JsonInput, JsonObject};
pub use return_enums::{
py_string_str, EitherBytes, EitherString, GenericArguments, GenericCollection, GenericMapping, JsonArgs, PyArgs,
py_string_str, EitherBytes, EitherString, GenericArguments, GenericCollection, GenericIterator, GenericMapping,
JsonArgs, PyArgs,
};

pub fn repr_string(v: &PyAny) -> PyResult<String> {
Expand Down
89 changes: 88 additions & 1 deletion src/input/return_enums.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Cow;

use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple};
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use crate::errors::{ErrorKind, InputValue, ValError, ValLineError, ValResult};
use crate::recursion_guard::RecursionGuard;
Expand Down Expand Up @@ -164,6 +164,93 @@ derive_from!(GenericMapping, PyDict, PyDict);
derive_from!(GenericMapping, PyGetAttr, PyAny);
derive_from!(GenericMapping, JsonObject, JsonObject);

#[derive(Debug, Clone)]
pub enum GenericIterator {
PyIterator(GenericPyIterator),
JsonArray(GenericJsonIterator),
}

impl From<JsonArray> for GenericIterator {
fn from(array: JsonArray) -> Self {
let length = array.len();
let json_iter = GenericJsonIterator {
array,
length,
index: 0,
};
Self::JsonArray(json_iter)
}
}

impl From<&PyAny> for GenericIterator {
fn from(obj: &PyAny) -> Self {
let py_iter = GenericPyIterator {
obj: obj.to_object(obj.py()),
iter: obj.iter().unwrap().into_py(obj.py()),
index: 0,
};
Self::PyIterator(py_iter)
}
}

#[derive(Debug, Clone)]
pub struct GenericPyIterator {
obj: PyObject,
iter: Py<PyIterator>,
index: usize,
}

impl GenericPyIterator {
pub fn next<'a>(&'a mut self, py: Python<'a>) -> PyResult<Option<(&'a PyAny, usize)>> {
match self.iter.as_ref(py).next() {
Some(Ok(next)) => {
let a = (next, self.index);
self.index += 1;
Ok(Some(a))
}
Some(Err(err)) => Err(err),
None => Ok(None),
}
}

pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny {
self.obj.as_ref(py)
}

pub fn index(&self) -> usize {
self.index
}
}

#[derive(Debug, Clone)]
pub struct GenericJsonIterator {
array: JsonArray,
length: usize,
index: usize,
}

impl GenericJsonIterator {
pub fn next(&mut self, _py: Python) -> PyResult<Option<(&JsonInput, usize)>> {
if self.index < self.length {
let next = unsafe { self.array.get_unchecked(self.index) };
let a = (next, self.index);
self.index += 1;
Ok(Some(a))
} else {
Ok(None)
}
}

pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny {
let input = JsonInput::Array(self.array.clone());
input.to_object(py).into_ref(py)
}

pub fn index(&self) -> usize {
self.index
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub struct PyArgs<'a> {
pub args: Option<&'a PyTuple>,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/frozenset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::input::{GenericCollection, Input};
use crate::recursion_guard::RecursionGuard;

use super::list::generic_collection_build;
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

#[derive(Debug, Clone)]
pub struct FrozenSetValidator {
Expand Down
29 changes: 5 additions & 24 deletions src/validators/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::input::Input;
use crate::questions::Question;
use crate::recursion_guard::RecursionGuard;

use super::generator::InternalValidator;
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

pub struct FunctionBuilder;
Expand Down Expand Up @@ -203,13 +204,7 @@ impl Validator for FunctionWrapValidator {
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let validator_kwarg = ValidatorCallable {
validator: self.validator.clone(),
slots: slots.to_vec(),
data: extra.data.map(|d| d.into_py(py)),
field: extra.field.map(|f| f.to_string()),
strict: extra.strict,
context: extra.context.map(|d| d.into_py(py)),
recursion_guard: recursion_guard.clone(),
validator: InternalValidator::new(py, "ValidatorCallable", &self.validator, slots, extra, recursion_guard),
};
let kwargs = kwargs!(
py,
Expand All @@ -236,16 +231,10 @@ impl Validator for FunctionWrapValidator {
}
}

#[pyclass]
#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
struct ValidatorCallable {
validator: Box<CombinedValidator>,
slots: Vec<CombinedValidator>,
data: Option<Py<PyDict>>,
field: Option<String>,
strict: Option<bool>,
context: Option<PyObject>,
recursion_guard: RecursionGuard,
validator: InternalValidator,
}

#[pymethods]
Expand All @@ -258,15 +247,7 @@ impl ValidatorCallable {
},
None => None,
};
let extra = Extra {
data: self.data.as_ref().map(|data| data.as_ref(py)),
field: self.field.as_deref(),
strict: self.strict,
context: self.context.as_ref().map(|data| data.as_ref(py)),
};
self.validator
.validate(py, input_value, &extra, &self.slots, &mut self.recursion_guard)
.map_err(|e| ValidationError::from_val_error(py, "Model".to_object(py), e, outer_location))
self.validator.validate(py, input_value, outer_location)
}

fn __repr__(&self) -> String {
Expand Down
Loading

0 comments on commit c13283d

Please sign in to comment.