From 3d8295efe31bf36f3350f7bd1a0ea240cc566589 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Thu, 15 Aug 2024 15:36:54 +0100 Subject: [PATCH] Support complex numbers (#1331) --- generate_self_schema.py | 2 +- python/pydantic_core/core_schema.py | 46 ++++++ src/errors/types.rs | 5 + src/input/input_abstract.rs | 4 +- src/input/input_json.rs | 33 +++++ src/input/input_python.rs | 46 +++++- src/input/input_string.rs | 9 ++ src/input/return_enums.rs | 29 +++- src/serializers/infer.rs | 24 ++- src/serializers/ob_type.rs | 8 +- src/serializers/shared.rs | 2 + src/serializers/type_serializers/complex.rs | 95 ++++++++++++ src/serializers/type_serializers/mod.rs | 1 + src/validators/complex.rs | 75 ++++++++++ src/validators/mod.rs | 3 + tests/serializers/test_complex.py | 39 +++++ tests/test_errors.py | 10 ++ tests/test_schema_functions.py | 1 + tests/validators/test_complex.py | 154 ++++++++++++++++++++ tests/validators/test_dict.py | 23 +++ 20 files changed, 601 insertions(+), 8 deletions(-) create mode 100644 src/serializers/type_serializers/complex.rs create mode 100644 src/validators/complex.rs create mode 100644 tests/serializers/test_complex.py create mode 100644 tests/validators/test_complex.py diff --git a/generate_self_schema.py b/generate_self_schema.py index ac12062f3..aeb411c6d 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -50,7 +50,7 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901 if isinstance(obj, str): return {'type': obj} - elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal): + elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex): return {'type': obj.__name__.lower()} elif is_typeddict(obj): return type_dict_schema(obj, definitions) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 4ac24bd6c..886194cbc 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -745,6 +745,48 @@ def decimal_schema( ) +class ComplexSchema(TypedDict, total=False): + type: Required[Literal['complex']] + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def complex_schema( + *, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> ComplexSchema: + """ + Returns a schema that matches a complex value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.complex_schema() + v = SchemaValidator(schema) + assert v.validate_python('1+2j') == complex(1, 2) + assert v.validate_python(complex(1, 2)) == complex(1, 2) + ``` + + Args: + strict: Whether the value should be a complex object instance or a value that can be converted to a complex object + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='complex', + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + class StringSchema(TypedDict, total=False): type: Required[Literal['str']] pattern: Union[str, Pattern[str]] @@ -3796,6 +3838,7 @@ def definition_reference_schema( DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, + ComplexSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3851,6 +3894,7 @@ def definition_reference_schema( 'definitions', 'definition-ref', 'uuid', + 'complex', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] @@ -3956,6 +4000,8 @@ def definition_reference_schema( 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', + 'complex_type', + 'complex_str_parsing', ] diff --git a/src/errors/types.rs b/src/errors/types.rs index 8807ba129..ec129a63a 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -426,6 +426,9 @@ error_types! { DecimalWholeDigits { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, + // Complex errors + ComplexType {}, + ComplexStrParsing {}, } macro_rules! render { @@ -569,6 +572,8 @@ impl ErrorType { Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", + Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", + Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index b0e058d9b..24a2e0239 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -10,7 +10,7 @@ use crate::tools::py_err; use crate::validators::ValBytesMode; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; -use super::return_enums::{EitherBytes, EitherInt, EitherString}; +use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString}; use super::{EitherFloat, GenericIterator, ValidationMatch}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -173,6 +173,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValMatch>; + + fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValMatch>; } /// The problem to solve here is that iterating collections often returns owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 3adc36ba6..5eb505f6e 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -8,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::input::return_enums::EitherComplex; use crate::lookup_key::{LookupKey, LookupPath}; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; use crate::validators::ValBytesMode; @@ -304,6 +306,30 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } + + fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValResult>> { + match self { + JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex( + &PyString::new_bound(py, s), + self, + )?))), + JsonValue::Float(f) => { + if !strict { + Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0]))) + } else { + Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self)) + } + } + JsonValue::Int(f) => { + if !strict { + Ok(ValidationMatch::lax(EitherComplex::Complex([(*f) as f64, 0.0]))) + } else { + Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self)) + } + } + _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), + } + } } /// Required for JSON Object keys so the string can behave like an Input @@ -440,6 +466,13 @@ impl<'py> Input<'py> for str { ) -> ValResult>> { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } + + fn validate_complex(&self, _strict: bool, py: Python<'py>) -> ValResult>> { + Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex( + self.to_object(py).downcast_bound::(py)?, + self, + )?))) + } } impl BorrowInput<'_> for &'_ String { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 7840a825a..46c32a9de 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -5,15 +5,17 @@ use pyo3::prelude::*; use pyo3::types::PyType; use pyo3::types::{ - PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PyMapping, PySet, PyString, PyTime, PyTuple, + PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, + PyList, PyMapping, PySet, PyString, PyTime, PyTuple, }; use pyo3::PyTypeCheck; +use pyo3::PyTypeInfo; use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::validators::Exactness; use crate::validators::ValBytesMode; @@ -25,6 +27,7 @@ use super::datetime::{ EitherTime, }; use super::input_abstract::ValMatch; +use super::return_enums::EitherComplex; use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch}; use super::shared::{ decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int, @@ -598,6 +601,45 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } + + fn validate_complex<'a>(&'a self, strict: bool, py: Python<'py>) -> ValResult>> { + if let Ok(complex) = self.downcast::() { + return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned()))); + } + if strict { + return Err(ValError::new( + ErrorType::IsInstanceOf { + class: PyComplex::type_object_bound(py) + .qualname() + .and_then(|name| name.extract()) + .unwrap_or_else(|_| "complex".to_owned()), + context: None, + }, + self, + )); + } + + if let Ok(s) = self.downcast::() { + // If input is not a valid complex string, instead of telling users to correct + // the string, it makes more sense to tell them to provide any acceptable value + // since they might have just given values of some incorrect types instead + // of actually trying some complex strings. + if let Ok(c) = string_to_complex(s, self) { + return Ok(ValidationMatch::lax(EitherComplex::Py(c))); + } + } else if self.is_exact_instance_of::() { + return Ok(ValidationMatch::lax(EitherComplex::Complex([ + self.extract::().unwrap(), + 0.0, + ]))); + } else if self.is_exact_instance_of::() { + return Ok(ValidationMatch::lax(EitherComplex::Complex([ + self.extract::().unwrap() as f64, + 0.0, + ]))); + } + Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) + } } impl<'py> BorrowInput<'py> for Bound<'py, PyAny> { diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 3ef1b58ce..7adcaeb28 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -7,6 +7,7 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult} use crate::input::py_string_str; use crate::lookup_key::{LookupKey, LookupPath}; use crate::tools::safe_repr; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; use crate::validators::ValBytesMode; @@ -14,6 +15,7 @@ use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, }; use super::input_abstract::{Never, ValMatch}; +use super::return_enums::EitherComplex; use super::shared::{str_as_bool, str_as_float, str_as_int}; use super::{ Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input, @@ -225,6 +227,13 @@ impl<'py> Input<'py> for StringMapping<'py> { Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } + + fn validate_complex(&self, _strict: bool, _py: Python<'py>) -> ValResult>> { + match self { + Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), + } + } } impl<'py> BorrowInput<'py> for StringMapping<'py> { diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 6b62ac85a..593599bab 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -12,7 +12,7 @@ use pyo3::intern; use pyo3::prelude::*; #[cfg(not(PyPy))] use pyo3::types::PyFunction; -use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString}; +use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString}; use serde::{ser::Error, Serialize, Serializer}; @@ -724,3 +724,30 @@ impl ToPyObject for Int { } } } + +#[derive(Clone)] +pub enum EitherComplex<'a> { + Complex([f64; 2]), + Py(Bound<'a, PyComplex>), +} + +impl<'a> IntoPy for EitherComplex<'a> { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py), + Self::Py(c) => c.into_py(py), + } + } +} + +impl<'a> EitherComplex<'a> { + pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] { + match self { + EitherComplex::Complex(f) => *f, + EitherComplex::Py(f) => [ + f.getattr(intern!(py, "real")).unwrap().extract().unwrap(), + f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(), + ], + } + } +} diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 514a6711b..032ede46b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; +use pyo3::types::PyComplex; use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple}; use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; @@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known( } PyList::new_bound(py, items).into_py(py) } + ObType::Complex => { + let dict = value.downcast::()?; + let new_dict = PyDict::new_bound(py); + let _ = new_dict.set_item("real", dict.get_item("real")?); + let _ = new_dict.set_item("imag", dict.get_item("imag")?); + new_dict.into_py(py) + } ObType::Path => value.str()?.into_py(py), ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py), ObType::Unknown => { @@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known( ); iter.into_py(py) } + ObType::Complex => { + let dict = value.downcast::()?; + let new_dict = PyDict::new_bound(py); + let _ = new_dict.set_item("real", dict.get_item("real")?); + let _ = new_dict.set_item("imag", dict.get_item("imag")?); + new_dict.into_py(py) + } ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; @@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known( ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(Int), ObType::Bool => serialize!(bool), + ObType::Complex => { + let v = value.downcast::().map_err(py_err_se_err)?; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry(&"real", &v.real())?; + map.serialize_entry(&"imag", &v.imag())?; + map.end() + } ObType::Float | ObType::FloatSubclass => { let v = value.extract::().map_err(py_err_se_err)?; type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode) @@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>( } Ok(Cow::Owned(key_build.finish())) } - ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => { + ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => { py_err!(PyTypeError; "`{}` not valid as object key", ob_type) } ObType::Dataclass | ObType::PydanticSerializable => { diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 196cf27b7..8d20efaa9 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -1,8 +1,8 @@ use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::types::{ - PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PyNone, PySet, PyString, PyTime, PyTuple, PyType, + PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, + PyIterator, PyList, PyNone, PySet, PyString, PyTime, PyTuple, PyType, }; use pyo3::{intern, PyTypeInfo}; @@ -48,6 +48,7 @@ pub struct ObTypeLookup { pattern_object: PyObject, // uuid type uuid_object: PyObject, + complex: usize, } static TYPE_LOOKUP: GILOnceCell = GILOnceCell::new(); @@ -101,6 +102,7 @@ impl ObTypeLookup { .to_object(py), pattern_object: py.import_bound("re").unwrap().getattr("Pattern").unwrap().to_object(py), uuid_object: py.import_bound("uuid").unwrap().getattr("UUID").unwrap().to_object(py), + complex: PyComplex::type_object_raw(py) as usize, } } @@ -171,6 +173,7 @@ impl ObTypeLookup { ObType::Pattern => self.path_object.as_ptr() as usize == ob_type, ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type, ObType::Unknown => false, + ObType::Complex => self.complex == ob_type, }; if ans { @@ -426,6 +429,7 @@ pub enum ObType { Uuid, // unknown type Unknown, + Complex, } impl PartialEq for ObType { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index e7930512c..cb12f8840 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -142,6 +142,7 @@ combined_serializer! { Enum: super::type_serializers::enum_::EnumSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; Tuple: super::type_serializers::tuple::TupleSerializer; + Complex: super::type_serializers::complex::ComplexSerializer; } } @@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/complex.rs b/src/serializers/type_serializers/complex.rs new file mode 100644 index 000000000..5a525476e --- /dev/null +++ b/src/serializers/type_serializers/complex.rs @@ -0,0 +1,95 @@ +use std::borrow::Cow; + +use pyo3::prelude::*; +use pyo3::types::{PyComplex, PyDict}; + +use crate::definitions::DefinitionsBuilder; + +use super::{infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, TypeSerializer}; + +#[derive(Debug, Clone)] +pub struct ComplexSerializer {} + +impl BuildSerializer for ComplexSerializer { + const EXPECTED_TYPE: &'static str = "complex"; + fn build( + _schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + Ok(Self {}.into()) + } +} + +impl_py_gc_traverse!(ComplexSerializer {}); + +impl TypeSerializer for ComplexSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + match value.downcast::() { + Ok(py_complex) => match extra.mode { + SerMode::Json => { + let re = py_complex.real(); + let im = py_complex.imag(); + let mut s = format!("{im}j"); + if re != 0.0 { + let mut sign = ""; + if im >= 0.0 { + sign = "+"; + } + s = format!("{re}{sign}{s}"); + } + Ok(s.into_py(py)) + } + _ => Ok(value.into_py(py)), + }, + Err(_) => { + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + } + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self._invalid_as_json_key(key, extra, "complex") + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + match value.downcast::() { + Ok(py_complex) => { + let re = py_complex.real(); + let im = py_complex.imag(); + let mut s = format!("{im}j"); + if re != 0.0 { + let mut sign = ""; + if im >= 0.0 { + sign = "+"; + } + s = format!("{re}{sign}{s}"); + } + Ok(serializer.collect_str::(&s)?) + } + Err(_) => { + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + } + } + + fn get_name(&self) -> &str { + "complex" + } +} diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index da36f0bc1..dabd006a3 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -1,5 +1,6 @@ pub mod any; pub mod bytes; +pub mod complex; pub mod dataclass; pub mod datetime_etc; pub mod decimal; diff --git a/src/validators/complex.rs b/src/validators/complex.rs new file mode 100644 index 000000000..d1d9f6c35 --- /dev/null +++ b/src/validators/complex.rs @@ -0,0 +1,75 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; +use pyo3::types::{PyComplex, PyDict, PyString, PyType}; + +use crate::build_tools::is_strict; +use crate::errors::{ErrorTypeDefaults, ToErrorValue, ValError, ValResult}; +use crate::input::Input; + +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; + +static COMPLEX_TYPE: GILOnceCell> = GILOnceCell::new(); + +pub fn get_complex_type(py: Python) -> &Bound<'_, PyType> { + COMPLEX_TYPE + .get_or_init(py, || py.get_type_bound::().into()) + .bind(py) +} + +#[derive(Debug)] +pub struct ComplexValidator { + strict: bool, +} + +impl BuildValidator for ComplexValidator { + const EXPECTED_TYPE: &'static str = "complex"; + fn build( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + Ok(Self { + strict: is_strict(schema, config)?, + } + .into()) + } +} + +impl_py_gc_traverse!(ComplexValidator {}); + +impl Validator for ComplexValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let res = input.validate_complex(self.strict, py)?.unpack(state); + Ok(res.into_py(py)) + } + + fn get_name(&self) -> &str { + "complex" + } +} + +pub(crate) fn string_to_complex<'py>( + arg: &Bound<'py, PyString>, + input: impl ToErrorValue, +) -> ValResult> { + let py = arg.py(); + Ok(get_complex_type(py) + .call1((arg,)) + .map_err(|err| { + // Since arg is a string, the only possible error here is ValueError + // triggered by invalid complex strings and thus only this case is handled. + if err.is_instance_of::(py) { + ValError::new(ErrorTypeDefaults::ComplexStrParsing, input) + } else { + ValError::InternalErr(err) + } + })? + .downcast::()? + .to_owned()) +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 18c947313..5f88d5dc8 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -25,6 +25,7 @@ mod bytes; mod call; mod callable; mod chain; +pub(crate) mod complex; mod config; mod custom_error; mod dataclass; @@ -582,6 +583,7 @@ pub fn build_validator( // recursive (self-referencing) models definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, + complex::ComplexValidator, ) } @@ -735,6 +737,7 @@ pub enum CombinedValidator { DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + Complex(complex::ComplexValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/tests/serializers/test_complex.py b/tests/serializers/test_complex.py new file mode 100644 index 000000000..e7c98246a --- /dev/null +++ b/tests/serializers/test_complex.py @@ -0,0 +1,39 @@ +import math + +import pytest + +from pydantic_core import SchemaSerializer, core_schema + + +@pytest.mark.parametrize( + 'value,expected', + [ + (complex(-1.23e-4, 567.89), '-0.000123+567.89j'), + (complex(0, -1.23), '-1.23j'), + (complex(1.5, 0), '1.5+0j'), + (complex(1, 2), '1+2j'), + (complex(0, 1), '1j'), + (complex(0, 1e-500), '0j'), + (complex(-float('inf'), 2), '-inf+2j'), + (complex(float('inf'), 2), 'inf+2j'), + (complex(float('nan'), 2), 'NaN+2j'), + ], +) +def test_complex_json(value, expected): + v = SchemaSerializer(core_schema.complex_schema()) + c = v.to_python(value) + c_json = v.to_python(value, mode='json') + json_str = v.to_json(value).decode() + + assert c_json == expected + assert json_str == f'"{expected}"' + + if math.isnan(value.imag): + assert math.isnan(c.imag) + else: + assert c.imag == value.imag + + if math.isnan(value.real): + assert math.isnan(c.real) + else: + assert c.imag == value.imag diff --git a/tests/test_errors.py b/tests/test_errors.py index bd6f6214e..b8265f04e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -395,6 +395,16 @@ def f(input_value, info): 'Decimal input should have no more than 1 digit before the decimal point', {'whole_digits': 1}, ), + ( + 'complex_type', + 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', + None, + ), + ( + 'complex_str_parsing', + 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', + None, + ), ] diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index 0971f0b9a..6d96678c8 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -290,6 +290,7 @@ def args(*args, **kwargs): (core_schema.uuid_schema, args(), {'type': 'uuid'}), (core_schema.decimal_schema, args(), {'type': 'decimal'}), (core_schema.decimal_schema, args(multiple_of=5, gt=1.2), {'type': 'decimal', 'multiple_of': 5, 'gt': 1.2}), + (core_schema.complex_schema, args(), {'type': 'complex'}), ] diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py new file mode 100644 index 000000000..83c5d416d --- /dev/null +++ b/tests/validators/test_complex.py @@ -0,0 +1,154 @@ +import math +import platform +import re + +import pytest + +from pydantic_core import SchemaValidator, ValidationError + +from ..conftest import Err + +EXPECTED_PARSE_ERROR_MESSAGE = 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' +EXPECTED_TYPE_ERROR_MESSAGE = 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' +EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE = 'Input should be an instance of complex' + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (complex(2, 4), complex(2, 4)), + ('2', complex(2, 0)), + ('2j', complex(0, 2)), + ('+1.23e-4-5.67e+8J', complex(1.23e-4, -5.67e8)), + ('1.5-j', complex(1.5, -1)), + ('-j', complex(0, -1)), + ('j', complex(0, 1)), + (3, complex(3, 0)), + (2.0, complex(2, 0)), + ('1e-700j', complex(0, 0)), + ('', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ('\t( -1.23+4.5J \n', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ('foobar', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ([], Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ([('x', 'y')], Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ((), Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ((('x', 'y'),), Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ( + (type('Foobar', (), {'x': 1})()), + Err(EXPECTED_TYPE_ERROR_MESSAGE), + ), + ], + ids=repr, +) +def test_complex_cases(input_value, expected): + v = SchemaValidator({'type': 'complex'}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + assert v.validate_python(input_value) == expected + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (complex(2, 4), complex(2, 4)), + ('2', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('2j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('+1.23e-4-5.67e+8J', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('1.5-j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('-j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + (3, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + (2.0, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('1e-700j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('\t( -1.23+4.5J \n', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('foobar', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ], + ids=repr, +) +def test_complex_strict(input_value, expected): + v = SchemaValidator({'type': 'complex', 'strict': True}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + assert v.validate_python(input_value) == expected + + +@pytest.mark.xfail( + platform.python_implementation() == 'PyPy', + reason='PyPy cannot process this string due to a bug, even if this string is considered valid in python', +) +def test_valid_complex_string_with_space(): + v = SchemaValidator({'type': 'complex'}) + assert v.validate_python('\t( -1.23+4.5J )\n') == complex(-1.23, 4.5) + + +def test_nan_inf_complex(): + v = SchemaValidator({'type': 'complex'}) + c = v.validate_python('NaN+Infinityj') + # c != complex(float('nan'), float('inf')) as nan != nan, + # so we need to examine the values individually + assert math.isnan(c.real) + assert math.isinf(c.imag) + + +def test_overflow_complex(): + # Python simply converts too large float values to inf, so these strings + # are still valid, even if the numbers are out of range + v = SchemaValidator({'type': 'complex'}) + + c = v.validate_python('5e600j') + assert math.isinf(c.imag) + + c = v.validate_python('-5e600j') + assert math.isinf(c.imag) + + +def test_json_complex(): + v = SchemaValidator({'type': 'complex'}) + assert v.validate_json('"-1.23e+4+5.67e-8J"') == complex(-1.23e4, 5.67e-8) + assert v.validate_json('1') == complex(1, 0) + assert v.validate_json('1.0') == complex(1, 0) + # "1" is a valid complex string + assert v.validate_json('"1"') == complex(1, 0) + + with pytest.raises(ValidationError) as exc_info: + v.validate_json('{"real": 2, "imag": 4}') + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'complex_type', + 'loc': (), + 'msg': EXPECTED_TYPE_ERROR_MESSAGE, + 'input': {'real': 2, 'imag': 4}, + } + ] + + +def test_json_complex_strict(): + v = SchemaValidator({'type': 'complex', 'strict': True}) + assert v.validate_json('"-1.23e+4+5.67e-8J"') == complex(-1.23e4, 5.67e-8) + # "1" is a valid complex string + assert v.validate_json('"1"') == complex(1, 0) + + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): + v.validate_json('1') + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): + v.validate_json('1.0') + with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): + v.validate_json('{"real": 2, "imag": 4}') + + +def test_string_complex(): + v = SchemaValidator({'type': 'complex'}) + assert v.validate_strings('+1.23e-4-5.67e+8J') == complex(1.23e-4, -5.67e8) + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): + v.validate_strings("{'real': 1, 'imag': 0}") diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index 7b3ca19c8..4057ce76e 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -235,3 +235,26 @@ def test_json_dict(): assert exc_info.value.errors(include_url=False) == [ {'type': 'dict_type', 'loc': (), 'msg': 'Input should be an object', 'input': 1} ] + + +def test_dict_complex_key(): + v = SchemaValidator( + {'type': 'dict', 'keys_schema': {'type': 'complex', 'strict': True}, 'values_schema': {'type': 'str'}} + ) + assert v.validate_python({complex(1, 2): '1'}) == {complex(1, 2): '1'} + with pytest.raises(ValidationError, match='Input should be an instance of complex'): + assert v.validate_python({'1+2j': b'1'}) == {complex(1, 2): '1'} + + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'str'}}) + with pytest.raises( + ValidationError, match='Input should be a valid python complex object, a number, or a valid complex string' + ): + v.validate_python({'1+2ja': b'1'}) + + +def test_json_dict_complex_key(): + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'int'}}) + assert v.validate_json('{"1+2j": 2, "-3": 4}') == {complex(1, 2): 2, complex(-3, 0): 4} + assert v.validate_json('{"1+2j": 2, "infj": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} + with pytest.raises(ValidationError, match='Input should be a valid complex string'): + v.validate_json('{"1+2j": 2, "": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4}