Skip to content

Commit

Permalink
Support complex numbers (pydantic#1331)
Browse files Browse the repository at this point in the history
  • Loading branch information
changhc authored Aug 15, 2024
1 parent bb67044 commit 3d8295e
Show file tree
Hide file tree
Showing 20 changed files with 601 additions and 8 deletions.
2 changes: 1 addition & 1 deletion generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj, definitions)
Expand Down
46 changes: 46 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,48 @@ def decimal_schema(
)


class ComplexSchema(TypedDict, total=False):
type: Required[Literal['complex']]
strict: bool
ref: str
metadata: Any
serialization: SerSchema


def complex_schema(
*,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: SerSchema | None = None,
) -> ComplexSchema:
"""
Returns a schema that matches a complex value, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.complex_schema()
v = SchemaValidator(schema)
assert v.validate_python('1+2j') == complex(1, 2)
assert v.validate_python(complex(1, 2)) == complex(1, 2)
```
Args:
strict: Whether the value should be a complex object instance or a value that can be converted to a complex object
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='complex',
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: Union[str, Pattern[str]]
Expand Down Expand Up @@ -3796,6 +3838,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3851,6 +3894,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'complex',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down Expand Up @@ -3956,6 +4000,8 @@ def definition_reference_schema(
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
'complex_type',
'complex_str_parsing',
]


Expand Down
5 changes: 5 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Complex errors
ComplexType {},
ComplexStrParsing {},
}

macro_rules! render {
Expand Down Expand Up @@ -569,6 +572,8 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex",
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::tools::py_err;
use crate::validators::ValBytesMode;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
use super::{EitherFloat, GenericIterator, ValidationMatch};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -173,6 +173,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValMatch<EitherTimedelta<'py>>;

fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValMatch<EitherComplex<'py>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand Down
33 changes: 33 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

Expand Down Expand Up @@ -304,6 +306,30 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
_ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
match self {
JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
&PyString::new_bound(py, s),
self,
)?))),
JsonValue::Float(f) => {
if !strict {
Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0])))
} else {
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
}
}
JsonValue::Int(f) => {
if !strict {
Ok(ValidationMatch::lax(EitherComplex::Complex([(*f) as f64, 0.0])))
} else {
Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self))
}
}
_ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

/// Required for JSON Object keys so the string can behave like an Input
Expand Down Expand Up @@ -440,6 +466,13 @@ impl<'py> Input<'py> for str {
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>> {
bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
}

fn validate_complex(&self, _strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(
self.to_object(py).downcast_bound::<PyString>(py)?,
self,
)?)))
}
}

impl BorrowInput<'_> for &'_ String {
Expand Down
46 changes: 44 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ use pyo3::prelude::*;

use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple,
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
PyList, PyMapping, PySet, PyString, PyTime, PyTuple,
};

use pyo3::PyTypeCheck;
use pyo3::PyTypeInfo;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::ValBytesMode;
Expand All @@ -25,6 +27,7 @@ use super::datetime::{
EitherTime,
};
use super::input_abstract::ValMatch;
use super::return_enums::EitherComplex;
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
Expand Down Expand Up @@ -598,6 +601,45 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {

Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self))
}

fn validate_complex<'a>(&'a self, strict: bool, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
if let Ok(complex) = self.downcast::<PyComplex>() {
return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned())));
}
if strict {
return Err(ValError::new(
ErrorType::IsInstanceOf {
class: PyComplex::type_object_bound(py)
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "complex".to_owned()),
context: None,
},
self,
));
}

if let Ok(s) = self.downcast::<PyString>() {
// If input is not a valid complex string, instead of telling users to correct
// the string, it makes more sense to tell them to provide any acceptable value
// since they might have just given values of some incorrect types instead
// of actually trying some complex strings.
if let Ok(c) = string_to_complex(s, self) {
return Ok(ValidationMatch::lax(EitherComplex::Py(c)));
}
} else if self.is_exact_instance_of::<PyFloat>() {
return Ok(ValidationMatch::lax(EitherComplex::Complex([
self.extract::<f64>().unwrap(),
0.0,
])));
} else if self.is_exact_instance_of::<PyInt>() {
return Ok(ValidationMatch::lax(EitherComplex::Complex([
self.extract::<i64>().unwrap() as f64,
0.0,
])));
}
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
}

impl<'py> BorrowInput<'py> for Bound<'py, PyAny> {
Expand Down
9 changes: 9 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}
use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{Never, ValMatch};
use super::return_enums::EitherComplex;
use super::shared::{str_as_bool, str_as_float, str_as_int};
use super::{
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
Expand Down Expand Up @@ -225,6 +227,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

fn validate_complex(&self, _strict: bool, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
match self {
Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
}
}
}

impl<'py> BorrowInput<'py> for StringMapping<'py> {
Expand Down
29 changes: 28 additions & 1 deletion src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3::intern;
use pyo3::prelude::*;
#[cfg(not(PyPy))]
use pyo3::types::PyFunction;
use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};

use serde::{ser::Error, Serialize, Serializer};

Expand Down Expand Up @@ -724,3 +724,30 @@ impl ToPyObject for Int {
}
}
}

#[derive(Clone)]
pub enum EitherComplex<'a> {
Complex([f64; 2]),
Py(Bound<'a, PyComplex>),
}

impl<'a> IntoPy<PyObject> for EitherComplex<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
match self {
Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py),
Self::Py(c) => c.into_py(py),
}
}
}

impl<'a> EitherComplex<'a> {
pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] {
match self {
EitherComplex::Complex(f) => *f,
EitherComplex::Py(f) => [
f.getattr(intern!(py, "real")).unwrap().extract().unwrap(),
f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(),
],
}
}
}
24 changes: 23 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
Expand Down Expand Up @@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known(
}
PyList::new_bound(py, items).into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
ObType::Unknown => {
Expand Down Expand Up @@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known(
);
iter.into_py(py)
}
ObType::Complex => {
let dict = value.downcast::<PyDict>()?;
let new_dict = PyDict::new_bound(py);
let _ = new_dict.set_item("real", dict.get_item("real")?);
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
new_dict.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(Int),
ObType::Bool => serialize!(bool),
ObType::Complex => {
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry(&"real", &v.real())?;
map.serialize_entry(&"imag", &v.imag())?;
map.end()
}
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>().map_err(py_err_se_err)?;
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
Expand Down Expand Up @@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
Ok(Cow::Owned(key_build.finish()))
}
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
}
ObType::Dataclass | ObType::PydanticSerializable => {
Expand Down
Loading

0 comments on commit 3d8295e

Please sign in to comment.