Skip to content

Commit

Permalink
Adding tagged union serializer 🚀 (pydantic#1397)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Aug 15, 2024
1 parent 3d8295e commit fdd1e85
Show file tree
Hide file tree
Showing 12 changed files with 430 additions and 193 deletions.
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod union;
43 changes: 43 additions & 0 deletions src/common/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use pyo3::prelude::*;
use pyo3::{PyTraverseError, PyVisit};

use crate::lookup_key::LookupKey;
use crate::py_gc::PyGcTraverse;

#[derive(Debug, Clone)]
pub enum Discriminator {
/// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases
LookupKey(LookupKey),
/// call a function to find the tag to use
Function(PyObject),
}

impl Discriminator {
pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult<Self> {
if raw.is_callable() {
return Ok(Self::Function(raw.to_object(py)));
}

let lookup_key = LookupKey::from_py(py, raw, None)?;
Ok(Self::LookupKey(lookup_key))
}

pub fn to_string_py(&self, py: Python) -> PyResult<String> {
match self {
Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)),
Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()),
}
}
}

impl PyGcTraverse for Discriminator {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self {
Self::Function(obj) => visit.call(obj)?,
Self::LookupKey(_) => {}
}
Ok(())
}
}

pub(crate) const SMALL_UNION_THRESHOLD: usize = 4;
4 changes: 0 additions & 4 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ pub trait ValidatedDict<'py> {
where
Self: 'a;
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>>;
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>>;
// FIXME this is a bit of a leaky abstraction
fn is_py_get_attr(&self) -> bool {
false
Expand Down Expand Up @@ -282,9 +281,6 @@ impl<'py> ValidatedDict<'py> for Never {
fn get_item<'k>(&self, _key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
unreachable!()
}
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
unreachable!()
}
fn iterate<'a, R>(
&'a self,
_consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
4 changes: 0 additions & 4 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,6 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> {
key.json_get(self)
}

fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
None
}

fn iterate<'a, R>(
&'a self,
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
7 changes: 0 additions & 7 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,13 +817,6 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> {
matches!(self, Self::GetAttr(..))
}

fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
match self {
Self::Dict(dict) => Some(dict),
_ => None,
}
}

fn iterate<'a, R>(
&'a self,
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
3 changes: 0 additions & 3 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,6 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> {
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
key.py_get_string_mapping_item(&self.0)
}
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
None
}
fn iterate<'a, R>(
&'a self,
consumer: impl super::ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod py_gc;

mod argument_markers;
mod build_tools;
mod common;
mod definitions;
mod errors;
mod input;
Expand Down
3 changes: 2 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ combined_serializer! {
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
find_only: {
super::type_serializers::union::TaggedUnionBuilder;
super::type_serializers::other::ChainBuilder;
super::type_serializers::other::CustomErrorBuilder;
super::type_serializers::other::CallBuilder;
Expand Down Expand Up @@ -138,6 +137,7 @@ combined_serializer! {
Json: super::type_serializers::json::JsonSerializer;
JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer;
Union: super::type_serializers::union::UnionSerializer;
TaggedUnion: super::type_serializers::union::TaggedUnionSerializer;
Literal: super::type_serializers::literal::LiteralSerializer;
Enum: super::type_serializers::enum_::EnumSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Expand Down Expand Up @@ -247,6 +247,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TaggedUnion(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
Expand Down
Loading

0 comments on commit fdd1e85

Please sign in to comment.