diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 000000000..11f2e1ece --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1 @@ +pub(crate) mod union; diff --git a/src/common/union.rs b/src/common/union.rs new file mode 100644 index 000000000..17fe9ad90 --- /dev/null +++ b/src/common/union.rs @@ -0,0 +1,43 @@ +use pyo3::prelude::*; +use pyo3::{PyTraverseError, PyVisit}; + +use crate::lookup_key::LookupKey; +use crate::py_gc::PyGcTraverse; + +#[derive(Debug, Clone)] +pub enum Discriminator { + /// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases + LookupKey(LookupKey), + /// call a function to find the tag to use + Function(PyObject), +} + +impl Discriminator { + pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult { + if raw.is_callable() { + return Ok(Self::Function(raw.to_object(py))); + } + + let lookup_key = LookupKey::from_py(py, raw, None)?; + Ok(Self::LookupKey(lookup_key)) + } + + pub fn to_string_py(&self, py: Python) -> PyResult { + match self { + Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)), + Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()), + } + } +} + +impl PyGcTraverse for Discriminator { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + match self { + Self::Function(obj) => visit.call(obj)?, + Self::LookupKey(_) => {} + } + Ok(()) + } +} + +pub(crate) const SMALL_UNION_THRESHOLD: usize = 4; diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 24a2e0239..d85ae0a56 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -239,7 +239,6 @@ pub trait ValidatedDict<'py> { where Self: 'a; fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult)>>; - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>>; // FIXME this is a bit of a leaky abstraction fn is_py_get_attr(&self) -> bool { false @@ -282,9 +281,6 @@ impl<'py> ValidatedDict<'py> for Never { fn get_item<'k>(&self, _key: &'k LookupKey) -> ValResult)>> { unreachable!() } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - unreachable!() - } fn iterate<'a, R>( &'a self, _consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 5eb505f6e..66dd0b054 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -509,10 +509,6 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> { key.json_get(self) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - None - } - fn iterate<'a, R>( &'a self, consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 46c32a9de..6c0528222 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -817,13 +817,6 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> { matches!(self, Self::GetAttr(..)) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - match self { - Self::Dict(dict) => Some(dict), - _ => None, - } - } - fn iterate<'a, R>( &'a self, consumer: impl ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 7adcaeb28..190dd46bf 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -293,9 +293,6 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> { fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult)>> { key.py_get_string_mapping_item(&self.0) } - fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> { - None - } fn iterate<'a, R>( &'a self, consumer: impl super::ConsumeIterator, Self::Item<'a>)>, Output = R>, diff --git a/src/lib.rs b/src/lib.rs index eb486da66..d1e042863 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ mod py_gc; mod argument_markers; mod build_tools; +mod common; mod definitions; mod errors; mod input; diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index cb12f8840..8eb54c837 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -88,7 +88,6 @@ combined_serializer! { // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer` // but aren't actually used for serialization, e.g. their `build` method must return another serializer find_only: { - super::type_serializers::union::TaggedUnionBuilder; super::type_serializers::other::ChainBuilder; super::type_serializers::other::CustomErrorBuilder; super::type_serializers::other::CallBuilder; @@ -138,6 +137,7 @@ combined_serializer! { Json: super::type_serializers::json::JsonSerializer; JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer; Union: super::type_serializers::union::UnionSerializer; + TaggedUnion: super::type_serializers::union::TaggedUnionSerializer; Literal: super::type_serializers::literal::LiteralSerializer; Enum: super::type_serializers::enum_::EnumSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; @@ -247,6 +247,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit), CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::TaggedUnion(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 2915115c3..23aa874fa 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -1,3 +1,4 @@ +use ahash::AHashMap as HashMap; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; @@ -5,8 +6,12 @@ use smallvec::SmallVec; use std::borrow::Cow; use crate::build_tools::py_schema_err; +use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; -use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY}; +use crate::errors::write_truncated_to_50_bytes; +use crate::lookup_key::LookupKey; +use crate::serializers::type_serializers::py_err_se_err; +use crate::tools::{safe_repr, SchemaDict}; use crate::PydanticSerializationUnexpectedValue; use super::{ @@ -68,20 +73,33 @@ impl UnionSerializer { impl_py_gc_traverse!(UnionSerializer { choices }); -impl TypeSerializer for UnionSerializer { - fn to_python( - &self, - value: &Bound<'_, PyAny>, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - ) -> PyResult { - // try the serializers in left to right order with error_on fallback=true - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); +fn to_python( + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, + retry_with_lax_check: bool, +) -> PyResult { + // try the serializers in left to right order with error_on fallback=true + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - for comb_serializer in &self.choices { + for comb_serializer in choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(value.py()) { + true => (), + false => errors.push(err), + }, + } + } + + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(value.py()) { @@ -90,33 +108,40 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), - }, - } - } - } - - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } + } - extra.warnings.on_fallback_py(self.get_name(), value, extra)?; - infer_to_python(value, include, exclude, extra) + for err in &errors { + extra.warnings.custom_warning(err.to_string()); } - fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); + extra.warnings.on_fallback_py(name, value, extra)?; + infer_to_python(value, include, exclude, extra) +} - for comb_serializer in &self.choices { +fn json_key<'a>( + key: &'a Bound<'_, PyAny>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, + retry_with_lax_check: bool, +) -> PyResult> { + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + + for comb_serializer in choices { + match comb_serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(key.py()) { + true => (), + false => errors.push(err), + }, + } + } + + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { match comb_serializer.json_key(key, &new_extra) { Ok(v) => return Ok(v), Err(err) => match err.is_instance_of::(key.py()) { @@ -125,25 +150,233 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.json_key(key, &new_extra) { + } + + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + + extra.warnings.on_fallback_py(name, key, extra)?; + infer_json_key(key, extra) +} + +#[allow(clippy::too_many_arguments)] +fn serde_serialize( + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], + name: &str, + retry_with_lax_check: bool, +) -> Result { + let py = value.py(); + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); + + for comb_serializer in choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { + true => (), + false => errors.push(err), + }, + } + } + + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + for comb_serializer in choices { + match comb_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => match err.is_instance_of::(py) { + true => (), + false => errors.push(err), + }, + } + } + } + + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + + extra.warnings.on_fallback_ser::(name, value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) +} + +impl TypeSerializer for UnionSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + to_python( + value, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + json_key(key, extra, &self.choices, self.get_name(), self.retry_with_lax_check()) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + serde_serialize( + value, + serializer, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) + } + + fn get_name(&self) -> &str { + &self.name + } + + fn retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) + } +} + +#[derive(Debug)] +pub struct TaggedUnionSerializer { + discriminator: Discriminator, + lookup: HashMap, + choices: Vec, + name: String, +} + +impl BuildSerializer for TaggedUnionSerializer { + const EXPECTED_TYPE: &'static str = "tagged-union"; + + fn build( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?; + + // TODO: guarantee at least 1 choice + let choices_map: Bound = schema.get_as_req(intern!(py, "choices"))?; + let mut lookup = HashMap::with_capacity(choices_map.len()); + let mut choices = Vec::with_capacity(choices_map.len()); + + for (idx, (choice_key, choice_schema)) in choices_map.into_iter().enumerate() { + let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions)?; + choices.push(serializer); + lookup.insert(choice_key.to_string(), idx); + } + + let descr = choices + .iter() + .map(TypeSerializer::get_name) + .collect::>() + .join(", "); + + Ok(Self { + discriminator, + lookup, + choices, + name: format!("TaggedUnion[{descr}]"), + } + .into()) + } +} + +impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, choices }); + +impl TypeSerializer for TaggedUnionSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + if let Some(tag) = self.get_discriminator_value(value, extra) { + let tag_str = tag.to_string(); + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let serializer = &self.choices[serializer_index]; + + match serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), - Err(err) => match err.is_instance_of::(key.py()) { - true => (), - false => errors.push(err), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.to_python(value, include, exclude, &new_extra); + } + } + false => return Err(err), }, } } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + to_python( + value, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + let py = key.py(); + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + if let Some(tag) = self.get_discriminator_value(key, extra) { + let tag_str = tag.to_string(); + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let serializer = &self.choices[serializer_index]; + + match serializer.json_key(key, &new_extra) { + Ok(v) => return Ok(v), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + return serializer.json_key(key, &new_extra); + } + } + false => return Err(err), + }, + } + } } - extra.warnings.on_fallback_py(self.get_name(), key, extra)?; - infer_json_key(key, extra) + json_key(key, extra, &self.choices, self.get_name(), self.retry_with_lax_check()) } fn serde_serialize( @@ -157,36 +390,40 @@ impl TypeSerializer for UnionSerializer { let py = value.py(); let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new(); - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), - }, - } - } - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - for comb_serializer in &self.choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { + if let Some(tag) = self.get_discriminator_value(value, extra) { + let tag_str = tag.to_string(); + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let selected_serializer = &self.choices[serializer_index]; + + match selected_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => match err.is_instance_of::(value.py()) { - true => (), - false => errors.push(err), + Err(err) => match err.is_instance_of::(py) { + true => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + match selected_serializer.to_python(value, include, exclude, &new_extra) { + Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), + Err(err) => return Err(py_err_se_err(err)), + } + } + } + false => return Err(py_err_se_err(err)), }, } } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - - extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; - infer_serialize(value, serializer, include, exclude, extra) + serde_serialize( + value, + serializer, + include, + exclude, + extra, + &self.choices, + self.get_name(), + self.retry_with_lax_check(), + ) } fn get_name(&self) -> &str { @@ -198,24 +435,29 @@ impl TypeSerializer for UnionSerializer { } } -pub struct TaggedUnionBuilder; - -impl BuildSerializer for TaggedUnionBuilder { - const EXPECTED_TYPE: &'static str = "tagged-union"; - - fn build( - schema: &Bound<'_, PyDict>, - config: Option<&Bound<'_, PyDict>>, - definitions: &mut DefinitionsBuilder, - ) -> PyResult { - let schema_choices: Bound<'_, PyDict> = schema.get_as_req(intern!(schema.py(), "choices"))?; - let mut choices: Vec = Vec::with_capacity(schema_choices.len()); +impl TaggedUnionSerializer { + fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option> { + let py = value.py(); + let discriminator_value = match &self.discriminator { + Discriminator::LookupKey(lookup_key) => match lookup_key { + LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok().map(|obj| obj.to_object(py)), + _ => None, + }, + Discriminator::Function(func) => func.call1(py, (value,)).ok(), + }; + if discriminator_value.is_none() { + let input_str = safe_repr(value); + let mut value_str = String::with_capacity(100); + value_str.push_str("with value `"); + write_truncated_to_50_bytes(&mut value_str, input_str.to_cow()).expect("Writing to a `String` failed"); + value_str.push('`'); - for (_, value) in schema_choices { - if let Ok(choice_schema) = value.downcast::() { - choices.push(CombinedSerializer::build(choice_schema, config, definitions)?); - } + extra.warnings.custom_warning( + format!( + "Failed to get discriminator value for tagged union serialization {value_str} - defaulting to left to right union serialization." + ) + ); } - UnionSerializer::from_choices(choices) + discriminator_value } } diff --git a/src/tools.rs b/src/tools.rs index 121ae3880..adf64c91a 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -146,5 +146,3 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach pystring_fast_new(py, s, ascii_only) } } - -pub(crate) const UNION_ERR_SMALLVEC_CAPACITY: usize = 4; diff --git a/src/validators/union.rs b/src/validators/union.rs index da24868c5..747f6a0cf 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use std::str::FromStr; +use crate::py_gc::PyGcTraverse; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -8,11 +9,10 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; +use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, Input, ValidatedDict}; -use crate::lookup_key::LookupKey; -use crate::py_gc::PyGcTraverse; -use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY}; +use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; @@ -249,7 +249,7 @@ struct ChoiceLineErrors<'a> { enum MaybeErrors<'a> { Custom(&'a CustomError), - Errors(SmallVec<[ChoiceLineErrors<'a>; UNION_ERR_SMALLVEC_CAPACITY]>), + Errors(SmallVec<[ChoiceLineErrors<'a>; SMALL_UNION_THRESHOLD]>), } impl<'a> MaybeErrors<'a> { @@ -295,49 +295,6 @@ impl<'a> MaybeErrors<'a> { } } -#[derive(Debug, Clone)] -enum Discriminator { - /// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases - LookupKey(LookupKey), - /// call a function to find the tag to use - Function(PyObject), - /// Custom discriminator specifically for the root `Schema` union in self-schema - SelfSchema, -} - -impl Discriminator { - fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult { - if raw.is_callable() { - return Ok(Self::Function(raw.to_object(py))); - } else if let Ok(py_str) = raw.downcast::() { - if py_str.to_str()? == "self-schema-discriminator" { - return Ok(Self::SelfSchema); - } - } - - let lookup_key = LookupKey::from_py(py, raw, None)?; - Ok(Self::LookupKey(lookup_key)) - } - - fn to_string_py(&self, py: Python) -> PyResult { - match self { - Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)), - Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()), - Self::SelfSchema => Ok("self-schema".to_string()), - } - } -} - -impl PyGcTraverse for Discriminator { - fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { - match self { - Self::Function(obj) => visit.call(obj)?, - Self::LookupKey(_) | Self::SelfSchema => {} - } - Ok(()) - } -} - #[derive(Debug)] pub struct TaggedUnionValidator { discriminator: Discriminator, @@ -388,11 +345,6 @@ impl BuildValidator for TaggedUnionValidator { let key = intern!(py, "from_attributes"); let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true); - let descr = match discriminator { - Discriminator::SelfSchema => "self-schema".to_string(), - _ => descr, - }; - Ok(Self { discriminator, lookup, @@ -436,9 +388,6 @@ impl Validator for TaggedUnionValidator { self.find_call_validator(py, tag.bind(py), input, state) } } - Discriminator::SelfSchema => { - self.find_call_validator(py, self.self_schema_tag(py, input, state)?.as_any(), input, state) - } } } @@ -448,35 +397,6 @@ impl Validator for TaggedUnionValidator { } impl TaggedUnionValidator { - fn self_schema_tag<'py>( - &self, - py: Python<'py>, - input: &(impl Input<'py> + ?Sized), - state: &mut ValidationState<'_, 'py>, - ) -> ValResult> { - let dict = input.strict_dict()?; - let dict = dict.as_py_dict().expect("self schema is always a Python dictionary"); - let tag = match dict.get_item(intern!(py, "type"))? { - Some(t) => t.downcast_into::()?, - None => return Err(self.tag_not_found(input)), - }; - let tag = tag.to_str()?; - // custom logic to distinguish between different function and tuple schemas - if tag == "function" { - let Some(mode) = dict.get_item(intern!(py, "mode"))? else { - return Err(self.tag_not_found(input)); - }; - let tag = match mode.validate_str(true, false)?.into_inner().as_cow()?.as_ref() { - "plain" => Ok(intern!(py, "function-plain").to_owned()), - "wrap" => Ok(intern!(py, "function-wrap").to_owned()), - _ => Ok(intern!(py, "function").to_owned()), - }; - tag - } else { - Ok(state.maybe_cached_str(py, tag)) - } - } - fn find_call_validator<'py>( &self, py: Python<'py>, diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 83527f648..d97d52f03 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -628,6 +628,55 @@ def test_union_serializer_picks_exact_type_over_subclass_json( assert s.to_json(input_value) == json.dumps(expected_value).encode() +def test_tagged_union() -> None: + @dataclasses.dataclass + class ModelA: + field: int + tag: Literal['a'] = 'a' + + @dataclasses.dataclass + class ModelB: + field: int + tag: Literal['b'] = 'b' + + s = SchemaSerializer( + core_schema.tagged_union_schema( + choices={ + 'a': core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', + [ + core_schema.dataclass_field(name='field', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='tag', schema=core_schema.literal_schema(['a'])), + ], + ), + ['field', 'tag'], + ), + 'b': core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', + [ + core_schema.dataclass_field(name='field', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='tag', schema=core_schema.literal_schema(['b'])), + ], + ), + ['field', 'tag'], + ), + }, + discriminator='tag', + ) + ) + + assert 'TaggedUnionSerializer' in repr(s) + + model_a = ModelA(field=1) + model_b = ModelB(field=1) + assert s.to_python(model_a) == {'field': 1, 'tag': 'a'} + assert s.to_python(model_b) == {'field': 1, 'tag': 'b'} + + def test_union_float_int() -> None: s = SchemaSerializer(core_schema.union_schema([core_schema.float_schema(), core_schema.int_schema()]))