diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index ddbdff2d2..897fc0f37 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -129,13 +129,16 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_int() } - fn validate_float(&self, strict: bool) -> ValResult { - if strict { + fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult { + if ultra_strict { + self.ultra_strict_float() + } else if strict { self.strict_float() } else { self.lax_float() } } + fn ultra_strict_float(&self) -> ValResult; fn strict_float(&self) -> ValResult; #[cfg_attr(has_no_coverage, no_coverage)] fn lax_float(&self) -> ValResult { diff --git a/src/input/input_json.rs b/src/input/input_json.rs index a9c3b9779..bb3327296 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -141,6 +141,12 @@ impl<'a> Input<'a> for JsonInput { } } + fn ultra_strict_float(&self) -> ValResult { + match self { + JsonInput::Float(f) => Ok(*f), + _ => Err(ValError::new(ErrorType::FloatType, self)), + } + } fn strict_float(&self) -> ValResult { match self { JsonInput::Float(f) => Ok(*f), @@ -368,6 +374,10 @@ impl<'a> Input<'a> for String { } } + #[cfg_attr(has_no_coverage, no_coverage)] + fn ultra_strict_float(&self) -> ValResult { + self.strict_float() + } #[cfg_attr(has_no_coverage, no_coverage)] fn strict_float(&self) -> ValResult { Err(ValError::new(ErrorType::FloatType, self)) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 37bd0d30e..3f1e24967 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -4,8 +4,8 @@ use std::str::from_utf8; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{ - PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, - PySet, PyString, PyTime, PyTuple, PyType, + PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList, + PyMapping, PySet, PyString, PyTime, PyTuple, PyType, }; #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; @@ -289,6 +289,15 @@ impl<'a> Input<'a> for PyAny { } } + fn ultra_strict_float(&self) -> ValResult { + if matches!(self.is_instance_of::(), Ok(true)) { + Err(ValError::new(ErrorType::FloatType, self)) + } else if let Ok(float) = self.extract::() { + Ok(float) + } else { + Err(ValError::new(ErrorType::FloatType, self)) + } + } fn strict_float(&self) -> ValResult { if self.extract::().is_ok() { Err(ValError::new(ErrorType::FloatType, self)) @@ -298,7 +307,6 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorType::FloatType, self)) } } - fn lax_float(&self) -> ValResult { if let Ok(float) = self.extract::() { Ok(float) diff --git a/src/validators/any.rs b/src/validators/any.rs index 02bb4218f..9bcb0ab22 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -36,6 +36,14 @@ impl Validator for AnyValidator { Ok(input.to_object(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + false + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index b5ee544ff..7215da050 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -326,6 +326,16 @@ impl Validator for ArgumentsValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.parameters + .iter() + .any(|p| p.validator.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 51fcadce5..4e6dee141 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -42,6 +42,14 @@ impl Validator for BoolValidator { Ok(input.validate_bool(extra.strict.unwrap_or(self.strict))?.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 53abcc25b..632d0d116 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -49,6 +49,14 @@ impl Validator for BytesValidator { Ok(either_bytes.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } @@ -91,6 +99,14 @@ impl Validator for BytesConstrainedValidator { Ok(either_bytes.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { "constrained-bytes" } diff --git a/src/validators/call.rs b/src/validators/call.rs index fa92643d6..0632288a7 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -81,6 +81,20 @@ impl Validator for CallValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if let Some(return_validator) = &self.return_validator { + if return_validator.different_strict_behavior(build_context, ultra_strict) { + return true; + } + } + self.arguments_validator + .different_strict_behavior(build_context, ultra_strict) + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index d0b4252c4..369239dc4 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -37,6 +37,14 @@ impl Validator for CallableValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + false + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index 254f9d3fd..7f3275cec 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -85,6 +85,16 @@ impl Validator for ChainValidator { }) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.steps + .iter() + .any(|v| v.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 495ba2818..c13194d72 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -97,6 +97,14 @@ impl Validator for CustomErrorValidator { .map_err(|_| self.custom_error.as_val_error(input)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.validator.different_strict_behavior(build_context, ultra_strict) + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 005de082c..cd192fd85 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -380,6 +380,16 @@ impl Validator for DataclassArgsValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.fields + .iter() + .any(|f| f.validator.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { &self.validator_name } @@ -510,6 +520,18 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.validator.different_strict_behavior(build_context, ultra_strict) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/date.rs b/src/validators/date.rs index 97ca12c57..e29c92fa2 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -98,6 +98,14 @@ impl Validator for DateValidator { Ok(date.try_into_py(py)?) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 3e6bcf566..979e9adb7 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -102,6 +102,14 @@ impl Validator for DateTimeValidator { Ok(datetime.try_into_py(py)?) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index d1e3ce87b..1566b433f 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -107,6 +107,20 @@ impl Validator for DefinitionRefValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if let Some(build_context) = build_context { + // have to unwrap here, because we can't return an error from this function, should be okay + let validator = build_context.find_validator(self.validator_id).unwrap(); + validator.different_strict_behavior(None, ultra_strict) + } else { + false + } + } + fn get_name(&self) -> &str { &self.inner_name } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 696472add..b97de0784 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -80,6 +80,19 @@ impl Validator for DictValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.key_validator.different_strict_behavior(build_context, true) + || self.value_validator.different_strict_behavior(build_context, true) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/float.rs b/src/validators/float.rs index 97ce9313f..7b31a51ca 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -68,13 +68,21 @@ impl Validator for FloatValidator { _slots: &'data [CombinedValidator], _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let float = input.validate_float(extra.strict.unwrap_or(self.strict))?; + let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?; if !self.allow_inf_nan && !float.is_finite() { return Err(ValError::new(ErrorType::FiniteNumber, input)); } Ok(float.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + true + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } @@ -104,7 +112,7 @@ impl Validator for ConstrainedFloatValidator { _slots: &'data [CombinedValidator], _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let float = input.validate_float(extra.strict.unwrap_or(self.strict))?; + let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?; if !self.allow_inf_nan && !float.is_finite() { return Err(ValError::new(ErrorType::FiniteNumber, input)); } @@ -142,6 +150,15 @@ impl Validator for ConstrainedFloatValidator { } Ok(float.into_py(py)) } + + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + true + } + fn get_name(&self) -> &str { "constrained-float" } diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index 2fb5f0e1f..0afb3bff2 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -60,6 +60,21 @@ impl Validator for FrozenSetValidator { Ok(f_set.into_py(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + match self.item_validator { + Some(ref v) => v.different_strict_behavior(build_context, true), + None => false, + } + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/function.rs b/src/validators/function.rs index 0b63cb592..47b41c5bf 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -94,6 +94,19 @@ macro_rules! impl_validator { self._validate(validate, py, obj, extra) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.validator + .different_strict_behavior(build_context, ultra_strict) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } @@ -226,6 +239,15 @@ impl Validator for FunctionPlainValidator { r.map_err(|e| convert_err(py, e, input)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + // best guess, should we change this? + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } @@ -303,6 +325,18 @@ impl Validator for FunctionWrapValidator { self._validate(Py::new(py, handler)?.into_ref(py), py, obj, extra) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.validator.different_strict_behavior(build_context, ultra_strict) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 37bce0d8f..9b8f9d03f 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -68,6 +68,18 @@ impl Validator for GeneratorValidator { Ok(v_iterator.into_py(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if let Some(ref v) = self.item_validator { + v.different_strict_behavior(build_context, ultra_strict) + } else { + false + } + } + fn get_name(&self) -> &str { &self.name } @@ -238,6 +250,7 @@ impl InternalValidator { let extra = Extra { data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, + ultra_strict: false, context: self.context.as_ref().map(|data| data.as_ref(py)), field_name: None, self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), @@ -269,6 +282,7 @@ impl InternalValidator { let extra = Extra { data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, + ultra_strict: false, context: self.context.as_ref().map(|data| data.as_ref(py)), field_name: None, self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), diff --git a/src/validators/int.rs b/src/validators/int.rs index c48b4b183..390cf95bd 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -51,6 +51,14 @@ impl Validator for IntValidator { Ok(input.validate_int(extra.strict.unwrap_or(self.strict))?.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } @@ -113,6 +121,14 @@ impl Validator for ConstrainedIntValidator { Ok(int.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { "constrained-int" } diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index cc2f4f505..21b442998 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -90,6 +90,14 @@ impl Validator for IsInstanceValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + false + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index 49b4de093..071c864f8 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -61,6 +61,14 @@ impl Validator for IsSubclassValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + false + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/json.rs b/src/validators/json.rs index 2a10efc92..390b5b372 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -62,6 +62,18 @@ impl Validator for JsonValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if let Some(ref v) = self.validator { + v.different_strict_behavior(build_context, ultra_strict) + } else { + false + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 63eea325b..3e29d3fd1 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -64,6 +64,18 @@ impl Validator for LaxOrStrictValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.strict_validator.different_strict_behavior(build_context, true) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/list.rs b/src/validators/list.rs index e35779b79..fde7818da 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -128,6 +128,21 @@ impl Validator for ListValidator { Ok(output.into_py(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + match self.item_validator { + Some(ref v) => v.different_strict_behavior(build_context, true), + None => false, + } + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 0a5d128c9..ad0dbf756 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -88,6 +88,14 @@ impl Validator for LiteralSingleStringValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } @@ -134,6 +142,14 @@ impl Validator for LiteralSingleIntValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + true + } + fn get_name(&self) -> &str { &self.name } @@ -193,6 +209,14 @@ impl Validator for LiteralMultipleStringsValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } @@ -252,6 +276,14 @@ impl Validator for LiteralMultipleIntsValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + true + } + fn get_name(&self) -> &str { &self.name } @@ -338,6 +370,14 @@ impl Validator for LiteralGeneralValidator { )) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index c78c477d9..f42ed8510 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -179,6 +179,7 @@ impl SchemaValidator { let extra = Extra { data: None, strict, + ultra_strict: false, context, field_name: None, self_instance: None, @@ -462,6 +463,8 @@ pub struct Extra<'a> { pub field_name: Option<&'a str>, /// whether we're in strict or lax mode pub strict: Option, + /// whether we're in ultra-strict mode, only used occasionally in unions + pub ultra_strict: bool, /// context used in validator functions pub context: Option<&'a PyAny>, /// This is an instance of the model or dataclass being validated, when validation is performed from `__init__` @@ -480,10 +483,11 @@ impl<'a> Extra<'a> { } impl<'a> Extra<'a> { - pub fn as_strict(&self) -> Self { + pub fn as_strict(&self, ultra_strict: bool) -> Self { Self { data: self.data, strict: Some(true), + ultra_strict, context: self.context, field_name: self.field_name, self_instance: self.self_instance, @@ -628,6 +632,14 @@ pub trait Validator: Send + Sync + Clone + Debug { Err(py_err.into()) } + /// whether the validator behaves differently in strict mode, and in ultra strict mode + /// implementations should return true if any of their sub-validators return true + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool; + /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; diff --git a/src/validators/model.rs b/src/validators/model.rs index d565f9d0f..c92b8ee1c 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -162,6 +162,18 @@ impl Validator for ModelValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + self.validator.different_strict_behavior(build_context, ultra_strict) + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/none.rs b/src/validators/none.rs index 1720b47ea..cf2428bc5 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -37,6 +37,14 @@ impl Validator for NoneValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + _ultra_strict: bool, + ) -> bool { + false + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 997249108..402470579 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -46,6 +46,14 @@ impl Validator for NullableValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.validator.different_strict_behavior(build_context, ultra_strict) + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/set.rs b/src/validators/set.rs index 2ca290ae9..e16a6f08b 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -90,6 +90,21 @@ impl Validator for SetValidator { Ok(set.into_py(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + match self.item_validator { + Some(ref v) => v.different_strict_behavior(build_context, true), + None => false, + } + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/string.rs b/src/validators/string.rs index 5ebd08ce0..41efe669c 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -48,6 +48,14 @@ impl Validator for StrValidator { Ok(input.validate_str(extra.strict.unwrap_or(self.strict))?.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } @@ -119,6 +127,14 @@ impl Validator for StrConstrainedValidator { Ok(py_string.into_py(py)) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { "constrained-str" } diff --git a/src/validators/time.rs b/src/validators/time.rs index e535c381f..65b8013d5 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -90,6 +90,14 @@ impl Validator for TimeValidator { Ok(time.try_into_py(py)?) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 362328523..6bf2a3f15 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -89,6 +89,14 @@ impl Validator for TimeDeltaValidator { Ok(timedelta.try_into_py(py)?) } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index c2b03ecbf..cbdaff221 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -76,6 +76,21 @@ impl Validator for TupleVariableValidator { Ok(PyTuple::new(py, &output).into_py(py)) } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + match self.item_validator { + Some(ref v) => v.different_strict_behavior(build_context, true), + None => false, + } + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } @@ -223,6 +238,28 @@ impl Validator for TuplePositionalValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + if ultra_strict { + if self + .items_validators + .iter() + .any(|v| v.different_strict_behavior(build_context, true)) + { + true + } else if let Some(ref v) = self.extra_validator { + v.different_strict_behavior(build_context, true) + } else { + false + } + } else { + true + } + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 2139d0e76..64888949e 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -386,6 +386,16 @@ impl Validator for TypedDictValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.fields + .iter() + .any(|f| f.validator.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/union.rs b/src/validators/union.rs index b9495e39b..1f6fd7515 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -25,6 +25,8 @@ pub struct UnionValidator { custom_error: Option, strict: bool, name: String, + strict_required: bool, + ultra_strict_required: bool, } impl BuildValidator for UnionValidator { @@ -54,6 +56,8 @@ impl BuildValidator for UnionValidator { custom_error: CustomError::build(schema, config, build_context)?, strict: is_strict(schema, config)?, name: format!("{}[{descr}]", Self::EXPECTED_TYPE), + strict_required: true, + ultra_strict_required: false, } .into()) } @@ -84,12 +88,25 @@ impl Validator for UnionValidator { slots: &'data [CombinedValidator], recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { + if self.ultra_strict_required { + // do an ultra strict check first + let ultra_strict_extra = extra.as_strict(true); + if let Some(res) = self + .choices + .iter() + .map(|validator| validator.validate(py, input, &ultra_strict_extra, slots, recursion_guard)) + .find(ValResult::is_ok) + { + return res; + } + } + if extra.strict.unwrap_or(self.strict) { let mut errors: Option> = match self.custom_error { None => Some(Vec::with_capacity(self.choices.len())), _ => None, }; - let strict_extra = extra.as_strict(); + let strict_extra = extra.as_strict(false); for validator in &self.choices { let line_errors = match validator.validate(py, input, &strict_extra, slots, recursion_guard) { @@ -108,16 +125,18 @@ impl Validator for UnionValidator { Err(self.or_custom_error(errors, input)) } else { - // 1st pass: check if the value is an exact instance of one of the Union types, - // e.g. use validate in strict mode - let strict_extra = extra.as_strict(); - if let Some(res) = self - .choices - .iter() - .map(|validator| validator.validate(py, input, &strict_extra, slots, recursion_guard)) - .find(ValResult::is_ok) - { - return res; + if self.strict_required { + // 1st pass: check if the value is an exact instance of one of the Union types, + // e.g. use validate in strict mode + let strict_extra = extra.as_strict(false); + if let Some(res) = self + .choices + .iter() + .map(|validator| validator.validate(py, input, &strict_extra, slots, recursion_guard)) + .find(ValResult::is_ok) + { + return res; + } } let mut errors: Option> = match self.custom_error { @@ -145,6 +164,16 @@ impl Validator for UnionValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.choices + .iter() + .any(|v| v.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { &self.name } @@ -154,7 +183,10 @@ impl Validator for UnionValidator { } fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> { - self.choices.iter_mut().try_for_each(|v| v.complete(build_context)) + self.choices.iter_mut().try_for_each(|v| v.complete(build_context))?; + self.strict_required = self.different_strict_behavior(Some(build_context), false); + self.ultra_strict_required = self.different_strict_behavior(Some(build_context), true); + Ok(()) } } @@ -391,6 +423,16 @@ impl Validator for TaggedUnionValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.choices + .values() + .any(|v| v.different_strict_behavior(build_context, ultra_strict)) + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/url.rs b/src/validators/url.rs index c82699ade..6a1a4c9e3 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -86,6 +86,14 @@ impl Validator for UrlValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } @@ -206,6 +214,14 @@ impl Validator for MultiHostUrlValidator { } } + fn different_strict_behavior( + &self, + _build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + !ultra_strict + } + fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index ceb4986b7..dac578b8d 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -145,6 +145,14 @@ impl Validator for WithDefaultValidator { } } + fn different_strict_behavior( + &self, + build_context: Option<&BuildContext>, + ultra_strict: bool, + ) -> bool { + self.validator.different_strict_behavior(build_context, ultra_strict) + } + fn get_name(&self) -> &str { &self.name } diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index d0bb88419..8bb6a8256 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,6 +1,7 @@ import pytest +from dirty_equals import IsFloat, IsInt -from pydantic_core import SchemaError, SchemaValidator, ValidationError +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema from ..conftest import plain_repr @@ -326,3 +327,90 @@ def test_custom_error_type_context(): assert exc_info.value.errors() == [ {'type': 'less_than', 'loc': (), 'msg': 'Input should be less than 42', 'input': 123, 'ctx': {'lt': 42.0}} ] + + +def test_dirty_behaviour(): + """ + Check dirty-equals does what we expect. + """ + + assert 1 == IsInt(approx=1, delta=0) + assert 1.0 != IsInt(approx=1, delta=0) + assert 1 != IsFloat(approx=1, delta=0) + assert 1.0 == IsFloat(approx=1, delta=0) + + +def test_int_float(): + v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.float_schema()])) + assert 'strict_required:true' in plain_repr(v) + assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour + + assert v.validate_python(1) == IsInt(approx=1, delta=0) + assert v.validate_json('1') == IsInt(approx=1, delta=0) + assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) + assert v.validate_json('1.0') == IsFloat(approx=1, delta=0) + + v = SchemaValidator(core_schema.union_schema([core_schema.float_schema(), core_schema.int_schema()])) + assert v.validate_python(1) == IsInt(approx=1, delta=0) + assert v.validate_json('1') == IsInt(approx=1, delta=0) + assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) + assert v.validate_json('1.0') == IsFloat(approx=1, delta=0) + + +def test_str_float(): + v = SchemaValidator(core_schema.union_schema([core_schema.str_schema(), core_schema.float_schema()])) + + assert v.validate_python(1) == IsFloat(approx=1, delta=0) + assert v.validate_json('1') == IsFloat(approx=1, delta=0) + assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) + assert v.validate_json('1.0') == IsFloat(approx=1, delta=0) + + assert v.validate_python('1.0') == '1.0' + assert v.validate_python('1') == '1' + assert v.validate_json('"1.0"') == '1.0' + assert v.validate_json('"1"') == '1' + + v = SchemaValidator(core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()])) + assert v.validate_python(1) == IsFloat(approx=1, delta=0) + assert v.validate_json('1') == IsFloat(approx=1, delta=0) + assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) + assert v.validate_json('1.0') == IsFloat(approx=1, delta=0) + + assert v.validate_python('1.0') == '1.0' + assert v.validate_python('1') == '1' + assert v.validate_json('"1.0"') == '1.0' + assert v.validate_json('"1"') == '1' + + +def test_strict_check(): + v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.json_schema()])) + assert 'strict_required:true' in plain_repr(v) + assert 'ultra_strict_required:false' in plain_repr(v) + + +def test_no_strict_check(): + v = SchemaValidator(core_schema.union_schema([core_schema.is_instance_schema(int), core_schema.json_schema()])) + assert 'strict_required:false' in plain_repr(v) + assert 'ultra_strict_required:false' in plain_repr(v) + + assert v.validate_python(123) == 123 + assert v.validate_python('[1, 2, 3]') == [1, 2, 3] + + +def test_strict_reference(): + v = SchemaValidator( + core_schema.tuple_positional_schema( + [ + core_schema.float_schema(), + core_schema.union_schema( + [core_schema.int_schema(), core_schema.definition_reference_schema('tuple-ref')] + ), + ], + ref='tuple-ref', + ) + ) + assert 'strict_required:true' in plain_repr(v) + assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour + + assert repr(v.validate_python((1, 2))) == '(1.0, 2)' + assert repr(v.validate_python((1.0, (2.0, 3)))) == '(1.0, (2.0, 3))'