Skip to content

tidy up tagged_union_schema #1333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2466,8 +2466,8 @@ class TaggedUnionSchema(TypedDict, total=False):


def tagged_union_schema(
choices: Dict[Hashable, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable],
choices: Dict[Any, CoreSchema],
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Any],
*,
custom_error_type: str | None = None,
custom_error_message: str | None = None,
Expand Down
29 changes: 14 additions & 15 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Validator for things inside of a typing.Literal[]
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
use core::fmt::Debug;
use std::cmp::Ordering;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyList};
Expand Down Expand Up @@ -35,7 +34,7 @@ pub struct LiteralLookup<T: Debug> {
// Catch all for hashable types like Enum and bytes (the latter only because it is seldom used)
expected_py_dict: Option<Py<PyDict>>,
// Catch all for unhashable types like list
expected_py_list: Option<Py<PyList>>,
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming this might be more performant - are there any benchmarks that reflect that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No benchmarks, but otherwise yes, we don't need to make Python FFI calls here to maintain this internal state :)


pub values: Vec<T>,
}
Expand All @@ -46,7 +45,7 @@ impl<T: Debug> LiteralLookup<T> {
let mut expected_int = AHashMap::new();
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
let expected_py_dict = PyDict::new_bound(py);
let expected_py_list = PyList::empty_bound(py);
let mut expected_py_values = Vec::new();
let mut values = Vec::new();
for (k, v) in expected {
let id = values.len();
Expand All @@ -71,7 +70,7 @@ impl<T: Debug> LiteralLookup<T> {
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
expected_str.insert(str.to_string(), id);
} else if expected_py_dict.set_item(&k, id).is_err() {
expected_py_list.append((&k, id))?;
expected_py_values.push((k.as_unbound().clone_ref(py), id));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just trying to learn more about Rust syntax here - could you explain what `as_unbound().clone_ref(py)` does? Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline - this is PyO3 code that converts a `Bound<'py, PyAny>` into a `Py<PyAny>`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I'll approve!

}
}

Expand All @@ -92,9 +91,9 @@ impl<T: Debug> LiteralLookup<T> {
true => None,
false => Some(expected_py_dict.into()),
},
expected_py_list: match expected_py_list.is_empty() {
expected_py_values: match expected_py_values.is_empty() {
true => None,
false => Some(expected_py_list.into()),
false => Some(expected_py_values),
},
values,
})
Expand Down Expand Up @@ -143,23 +142,23 @@ impl<T: Debug> LiteralLookup<T> {
}
}
}
// cache py_input if needed, since we might need it for multiple lookups
let mut py_input = None;
if let Some(expected_py_dict) = &self.expected_py_dict {
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
// inputs will produce a TypeError, which in this case we just want to treat equivalently
// to a failed lookup
if let Ok(Some(v)) = expected_py_dict.bind(py).get_item(input) {
if let Ok(Some(v)) = expected_py_dict.bind(py).get_item(&*py_input) {
let id: usize = v.extract().unwrap();
return Ok(Some((input, &self.values[id])));
}
};
if let Some(expected_py_list) = &self.expected_py_list {
for item in expected_py_list.bind(py) {
let (k, id): (Bound<PyAny>, usize) = item.extract()?;
if k.compare(input.to_object(py).bind(py))
.unwrap_or(Ordering::Less)
.is_eq()
{
return Ok(Some((input, &self.values[id])));
if let Some(expected_py_values) = &self.expected_py_values {
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
for (k, id) in expected_py_values {
if k.bind(py).eq(&*py_input).unwrap_or(false) {
return Ok(Some((input, &self.values[*id])));
}
}
};
Expand Down
2 changes: 0 additions & 2 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,9 @@ impl BuildValidator for TaggedUnionValidator {
let mut tags_repr = String::with_capacity(50);
let mut descr = String::with_capacity(50);
let mut first = true;
let mut discriminators = Vec::with_capacity(choices.len());
let schema_choices: Bound<PyDict> = schema.get_as_req(intern!(py, "choices"))?;
let mut lookup_map = Vec::with_capacity(choices.len());
for (choice_key, choice_schema) in schema_choices {
discriminators.push(choice_key.clone());
let validator = build_validator(&choice_schema, config, definitions)?;
let tag_repr = choice_key.repr()?.to_string();
if first {
Expand Down
Loading