Skip to content

Commit

Permalink
Add ultra-strict mode for better union decisions (#536)
Browse files Browse the repository at this point in the history
* make strict float not allow an int

* adding different_strict_behavior()

* revert float test changes

* add tests

* improve testing structure

* support definition references, test

* update validate_float
  • Loading branch information
samuelcolvin authored Apr 12, 2023
1 parent dd98ccc commit 5c90851
Show file tree
Hide file tree
Showing 41 changed files with 668 additions and 21 deletions.
7 changes: 5 additions & 2 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,16 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_int()
}

fn validate_float(&self, strict: bool) -> ValResult<f64> {
if strict {
fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
if ultra_strict {
self.ultra_strict_float()
} else if strict {
self.strict_float()
} else {
self.lax_float()
}
}
fn ultra_strict_float(&self) -> ValResult<f64>;
fn strict_float(&self) -> ValResult<f64>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_float(&self) -> ValResult<f64> {
Expand Down
10 changes: 10 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ impl<'a> Input<'a> for JsonInput {
}
}

fn ultra_strict_float(&self) -> ValResult<f64> {
match self {
JsonInput::Float(f) => Ok(*f),
_ => Err(ValError::new(ErrorType::FloatType, self)),
}
}
fn strict_float(&self) -> ValResult<f64> {
match self {
JsonInput::Float(f) => Ok(*f),
Expand Down Expand Up @@ -368,6 +374,10 @@ impl<'a> Input<'a> for String {
}
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn ultra_strict_float(&self) -> ValResult<f64> {
self.strict_float()
}
#[cfg_attr(has_no_coverage, no_coverage)]
fn strict_float(&self) -> ValResult<f64> {
Err(ValError::new(ErrorType::FloatType, self))
Expand Down
14 changes: 11 additions & 3 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::str::from_utf8;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping,
PySet, PyString, PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple, PyType,
};
#[cfg(not(PyPy))]
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
Expand Down Expand Up @@ -289,6 +289,15 @@ impl<'a> Input<'a> for PyAny {
}
}

fn ultra_strict_float(&self) -> ValResult<f64> {
if matches!(self.is_instance_of::<PyInt>(), Ok(true)) {
Err(ValError::new(ErrorType::FloatType, self))
} else if let Ok(float) = self.extract::<f64>() {
Ok(float)
} else {
Err(ValError::new(ErrorType::FloatType, self))
}
}
fn strict_float(&self) -> ValResult<f64> {
if self.extract::<bool>().is_ok() {
Err(ValError::new(ErrorType::FloatType, self))
Expand All @@ -298,7 +307,6 @@ impl<'a> Input<'a> for PyAny {
Err(ValError::new(ErrorType::FloatType, self))
}
}

fn lax_float(&self) -> ValResult<f64> {
if let Ok(float) = self.extract::<f64>() {
Ok(float)
Expand Down
8 changes: 8 additions & 0 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ impl Validator for AnyValidator {
Ok(input.to_object(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
_ultra_strict: bool,
) -> bool {
false
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
10 changes: 10 additions & 0 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,16 @@ impl Validator for ArgumentsValidator {
}
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
self.parameters
.iter()
.any(|p| p.validator.different_strict_behavior(build_context, ultra_strict))
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
8 changes: 8 additions & 0 deletions src/validators/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ impl Validator for BoolValidator {
Ok(input.validate_bool(extra.strict.unwrap_or(self.strict))?.into_py(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
!ultra_strict
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
16 changes: 16 additions & 0 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl Validator for BytesValidator {
Ok(either_bytes.into_py(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
!ultra_strict
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down Expand Up @@ -91,6 +99,14 @@ impl Validator for BytesConstrainedValidator {
Ok(either_bytes.into_py(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
!ultra_strict
}

fn get_name(&self) -> &str {
"constrained-bytes"
}
Expand Down
14 changes: 14 additions & 0 deletions src/validators/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ impl Validator for CallValidator {
}
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
if let Some(return_validator) = &self.return_validator {
if return_validator.different_strict_behavior(build_context, ultra_strict) {
return true;
}
}
self.arguments_validator
.different_strict_behavior(build_context, ultra_strict)
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
8 changes: 8 additions & 0 deletions src/validators/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ impl Validator for CallableValidator {
}
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
_ultra_strict: bool,
) -> bool {
false
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
10 changes: 10 additions & 0 deletions src/validators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ impl Validator for ChainValidator {
})
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
self.steps
.iter()
.any(|v| v.different_strict_behavior(build_context, ultra_strict))
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
8 changes: 8 additions & 0 deletions src/validators/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ impl Validator for CustomErrorValidator {
.map_err(|_| self.custom_error.as_val_error(input))
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
self.validator.different_strict_behavior(build_context, ultra_strict)
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
22 changes: 22 additions & 0 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,16 @@ impl Validator for DataclassArgsValidator {
}
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
self.fields
.iter()
.any(|f| f.validator.different_strict_behavior(build_context, ultra_strict))
}

fn get_name(&self) -> &str {
&self.validator_name
}
Expand Down Expand Up @@ -510,6 +520,18 @@ impl Validator for DataclassValidator {
Ok(obj.to_object(py))
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
if ultra_strict {
self.validator.different_strict_behavior(build_context, ultra_strict)
} else {
true
}
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
8 changes: 8 additions & 0 deletions src/validators/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ impl Validator for DateValidator {
Ok(date.try_into_py(py)?)
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
!ultra_strict
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
8 changes: 8 additions & 0 deletions src/validators/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ impl Validator for DateTimeValidator {
Ok(datetime.try_into_py(py)?)
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
!ultra_strict
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down
14 changes: 14 additions & 0 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ impl Validator for DefinitionRefValidator {
}
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
if let Some(build_context) = build_context {
// have to unwrap here, because we can't return an error from this function, should be okay
let validator = build_context.find_validator(self.validator_id).unwrap();
validator.different_strict_behavior(None, ultra_strict)
} else {
false
}
}

fn get_name(&self) -> &str {
&self.inner_name
}
Expand Down
13 changes: 13 additions & 0 deletions src/validators/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ impl Validator for DictValidator {
}
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
if ultra_strict {
self.key_validator.different_strict_behavior(build_context, true)
|| self.value_validator.different_strict_behavior(build_context, true)
} else {
true
}
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
21 changes: 19 additions & 2 deletions src/validators/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,21 @@ impl Validator for FloatValidator {
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorType::FiniteNumber, input));
}
Ok(float.into_py(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
_ultra_strict: bool,
) -> bool {
true
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}
Expand Down Expand Up @@ -104,7 +112,7 @@ impl Validator for ConstrainedFloatValidator {
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let float = input.validate_float(extra.strict.unwrap_or(self.strict))?;
let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
if !self.allow_inf_nan && !float.is_finite() {
return Err(ValError::new(ErrorType::FiniteNumber, input));
}
Expand Down Expand Up @@ -142,6 +150,15 @@ impl Validator for ConstrainedFloatValidator {
}
Ok(float.into_py(py))
}

fn different_strict_behavior(
&self,
_build_context: Option<&BuildContext<CombinedValidator>>,
_ultra_strict: bool,
) -> bool {
true
}

fn get_name(&self) -> &str {
"constrained-float"
}
Expand Down
15 changes: 15 additions & 0 deletions src/validators/frozenset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ impl Validator for FrozenSetValidator {
Ok(f_set.into_py(py))
}

fn different_strict_behavior(
&self,
build_context: Option<&BuildContext<CombinedValidator>>,
ultra_strict: bool,
) -> bool {
if ultra_strict {
match self.item_validator {
Some(ref v) => v.different_strict_behavior(build_context, true),
None => false,
}
} else {
true
}
}

fn get_name(&self) -> &str {
&self.name
}
Expand Down
Loading

0 comments on commit 5c90851

Please sign in to comment.