Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use std::fmt::Debug;

use enum_dispatch::enum_dispatch;

use ahash::AHashSet;
use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString};
use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyList, PyString};

use crate::build_tools::{py_error, SchemaDict, SchemaError};
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError};
Expand Down Expand Up @@ -69,7 +70,10 @@ impl SchemaValidator {
.map_err(|e| SchemaError::from_val_error(py, e))?;
let schema = schema_obj.as_ref(py);

let mut build_context = BuildContext::default();
let mut used_refs = AHashSet::new();
extract_used_refs(schema, &mut used_refs)?;
let mut build_context = BuildContext::new(used_refs);

let mut validator = build_validator(schema, config, &mut build_context)?;
validator.complete(&build_context)?;
let slots = build_context.into_slots()?;
Expand Down Expand Up @@ -219,7 +223,12 @@ impl SchemaValidator {
py.run(code, None, Some(locals))?;
let self_schema: &PyDict = locals.get_as_req(intern!(py, "self_schema"))?;

let mut build_context = BuildContext::default();
let mut used_refs = AHashSet::new();
// NOTE: we don't call `extract_used_refs` here for performance reasons; if more recursive references
// are used, they will need to be added manually here.
used_refs.insert("root-schema".to_string());
let mut build_context = BuildContext::new(used_refs);

let validator = match build_validator(self_schema, None, &mut build_context) {
Ok(v) => v,
Err(err) => return Err(SchemaError::new_err(format!("Error building self-schema:\n {}", err))),
Expand Down Expand Up @@ -260,26 +269,29 @@ pub trait BuildValidator: Sized {
-> PyResult<CombinedValidator>;
}

/// Logic to create a particular validator, called in the `validator_match` macro, then in turn by `build_validator`.
///
/// * `val_type` - the schema type name, only used to label the error raised when building fails.
/// * `schema_dict` - the schema to build the validator from; its optional `ref` key is inspected here.
/// * `config` - optional config dict, passed through to `T::build` untouched.
/// * `build_context` - records which refs are used and owns the slots for recursive validators.
///
/// Returns the validator produced by `T::build`, or a `RecursiveContainerValidator` pointing at the
/// slot that holds it when the schema carries a `ref` that is actually used.
fn build_single_validator<'a, T: BuildValidator>(
    val_type: &str,
    schema_dict: &'a PyDict,
    config: Option<&'a PyDict>,
    build_context: &mut BuildContext,
) -> PyResult<CombinedValidator> {
    let py = schema_dict.py();
    if let Some(schema_ref) = schema_dict.get_as::<String>(intern!(py, "ref"))? {
        // we only want to use a RecursiveContainerValidator if the ref is actually used,
        // this means refs can always be set without having an effect on the validator which is generated
        // unless it's used/referenced
        if build_context.ref_used(&schema_ref) {
            // the slot has to be reserved before building, so the id exists while the inner validator is built
            let slot_id = build_context.prepare_slot(schema_ref)?;
            // errors from this build are propagated unchanged (no "Error building ..." prefix is added)
            let inner_val = T::build(schema_dict, config, build_context)?;
            let name = inner_val.get_name().to_string();
            build_context.complete_slot(slot_id, inner_val)?;
            return Ok(recursive::RecursiveContainerValidator::create(slot_id, name));
        }
    }

    // no ref, or a ref nothing points at: build the validator directly and label any failure with its type
    T::build(schema_dict, config, build_context)
        .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))
}

// macro to build the match statement for validator selection
Expand Down Expand Up @@ -523,10 +535,23 @@ pub trait Validator: Send + Sync + Clone + Debug {
/// and therefore can't be owned by them directly.
#[derive(Default, Clone)]
pub struct BuildContext {
/// Refs that are referenced somewhere in the schema; supplied to `new` and queried via `ref_used`.
used_refs: AHashSet<String>,
/// One `(slot_ref, validator)` pair per slot. The validator is optional because a slot is added
/// before its validator has been built (presumably `None` until the slot is completed).
slots: Vec<(String, Option<CombinedValidator>)>,
}

impl BuildContext {
/// Create a context that knows which refs are used by the schema and has no slots yet.
pub fn new(used_refs: AHashSet<String>) -> Self {
    Self {
        used_refs,
        slots: Vec::new(),
    }
}

/// Check if a ref is used elsewhere in the schema.
///
/// Returns `true` when `ref_` is contained in the set of used refs this context holds.
pub fn ref_used(&self, ref_: &str) -> bool {
self.used_refs.contains(ref_)
}

/// First of a two-part process to add a new validator slot: we add the `slot_ref` to the array, but not the
/// actual `validator`, because we can't add the validator until it's built.
/// We need the `id` to build the validator, hence this two-step process.
Expand Down Expand Up @@ -584,3 +609,21 @@ impl BuildContext {
.collect()
}
}

/// Recursively scan `schema` and record every ref that a `recursive-ref` schema points at.
///
/// * A dict whose `type` is `"recursive-ref"` contributes its `schema_ref` to `refs` and is not
///   descended into.
/// * Any other dict has each of its values scanned; a list has each of its items scanned.
/// * Every other kind of object is ignored.
///
/// Returns an error if the `schema_ref` of a `recursive-ref` dict cannot be read, or if scanning
/// a nested value fails; scanning stops at the first error.
fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet<String>) -> PyResult<()> {
    match schema.cast_as::<PyDict>() {
        Ok(dict) => {
            let py = schema.py();
            let is_recursive_ref = matches!(dict.get_as(intern!(py, "type")), Ok(Some("recursive-ref")));
            if is_recursive_ref {
                let schema_ref: String = dict.get_as_req(intern!(py, "schema_ref"))?;
                refs.insert(schema_ref);
                Ok(())
            } else {
                dict.iter()
                    .try_for_each(|(_, value)| extract_used_refs(value, refs))
            }
        }
        // not a dict: only lists are worth descending into
        Err(_) => match schema.cast_as::<PyList>() {
            Ok(list) => list.iter().try_for_each(|item| extract_used_refs(item, refs)),
            Err(_) => Ok(()),
        },
    }
}
41 changes: 36 additions & 5 deletions tests/validators/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic_core import SchemaError, SchemaValidator, ValidationError

from ..conftest import Err
from ..conftest import Err, plain_repr
from .test_typed_dict import Cls


Expand All @@ -19,10 +19,7 @@ def test_branch_nullable():
'sub_branch': {
'schema': {
'type': 'default',
'schema': {
'type': 'union',
'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}],
},
'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}},
'default': None,
}
},
Expand All @@ -31,6 +28,7 @@ def test_branch_nullable():
)

assert v.validate_python({'name': 'root'}) == {'name': 'root', 'sub_branch': None}
assert plain_repr(v).startswith('SchemaValidator(name="typed-dict",validator=Recursive(RecursiveContainerValidator')

assert v.validate_python({'name': 'root', 'sub_branch': {'name': 'b1'}}) == (
{'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': None}}
Expand All @@ -40,6 +38,14 @@ def test_branch_nullable():
)


def test_unused_ref():
    """A `ref` that no `recursive-ref` points at must leave the plain typed-dict validator in place."""
    schema = {
        'type': 'typed-dict',
        'ref': 'Branch',
        'fields': {
            'name': {'schema': 'str'},
            'other': {'schema': 'int'},
        },
    }
    validator = SchemaValidator(schema)

    # the repr shows the validator is a bare TypedDict validator, not a recursive container
    expected_prefix = 'SchemaValidator(name="typed-dict",validator=TypedDict(TypedDictValidator'
    assert plain_repr(validator).startswith(expected_prefix)

    assert validator.validate_python({'name': 'root', 'other': '4'}) == {'name': 'root', 'other': 4}


def test_nullable_error():
v = SchemaValidator(
{
Expand Down Expand Up @@ -680,3 +686,28 @@ def test_many_uses_of_ref():

long_input = {'name': 'Anne', 'other_names': [f'p-{i}' for i in range(300)]}
assert v.validate_python(long_input) == long_input


def test_error_inside_recursive_wrapper():
    """A schema error raised below a typed-dict whose `ref` is used is reported with the field and inner validator."""
    nullable_branch = {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}
    # setting both `default` and `default_factory` is what makes the schema invalid
    conflicting_default = {
        'type': 'default',
        'schema': nullable_branch,
        'default': None,
        'default_factory': lambda x: 'foobar',
    }
    schema = {
        'type': 'typed-dict',
        'ref': 'Branch',
        'fields': {'sub_branch': {'schema': conflicting_default}},
    }

    with pytest.raises(SchemaError) as exc_info:
        SchemaValidator(schema)

    expected_message = (
        'Field "sub_branch":\n'
        ' SchemaError: Error building "default" validator:\n'
        " SchemaError: 'default' and 'default_factory' cannot be used together"
    )
    assert str(exc_info.value) == expected_message