Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use more explicit warning regarding serialization warning for missing fields #1415

Merged
merged 11 commits into from
Aug 22, 2024
46 changes: 0 additions & 46 deletions src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use core::fmt;
use std::borrow::Cow;

use pyo3::prelude::*;

mod line_error;
Expand Down Expand Up @@ -33,46 +30,3 @@ pub fn py_err_string(py: Python, err: PyErr) -> String {
Err(_) => "Unknown Error".to_string(),
}
}

// TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary
// with builtin methods once https://github.com/rust-lang/rust/issues/93743 is resolved
// These are just copy pasted from the current implementation
const fn is_utf8_char_boundary(value: u8) -> bool {
// This is bit magic equivalent to: b < 128 || b >= 192
(value as i8) >= -0x40
}

pub fn floor_char_boundary(value: &str, index: usize) -> usize {
if index >= value.len() {
value.len()
} else {
let lower_bound = index.saturating_sub(3);
let new_index = value.as_bytes()[lower_bound..=index]
.iter()
.rposition(|b| is_utf8_char_boundary(*b));

// SAFETY: we know that the character boundary will be within four bytes
unsafe { lower_bound + new_index.unwrap_unchecked() }
}
}

pub fn ceil_char_boundary(value: &str, index: usize) -> usize {
let upper_bound = Ord::min(index + 4, value.len());
value.as_bytes()[index..upper_bound]
.iter()
.position(|b| is_utf8_char_boundary(*b))
.map_or(upper_bound, |pos| pos + index)
}

pub fn write_truncated_to_50_bytes<F: fmt::Write>(f: &mut F, val: Cow<'_, str>) -> std::fmt::Result {
if val.len() > 50 {
write!(
f,
"{}...{}",
&val[0..floor_char_boundary(&val, 25)],
&val[ceil_char_boundary(&val, val.len() - 24)..]
)
} else {
write!(f, "{val}")
}
}
4 changes: 2 additions & 2 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::errors::LocItem;
use crate::get_pydantic_version;
use crate::input::InputType;
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState};
use crate::tools::{safe_repr, SchemaDict};
use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict};

use super::line_error::ValLineError;
use super::location::Location;
Expand Down Expand Up @@ -526,7 +526,7 @@ impl PyLineError {
let input_value = self.input_value.bind(py);
let input_str = safe_repr(input_value);
write!(output, ", input_value=")?;
super::write_truncated_to_50_bytes(&mut output, input_str.to_cow())?;
write_truncated_to_limited_bytes(&mut output, &input_str.to_string(), 50)?;

if let Ok(type_) = input_value.get_type().qualname() {
write!(output, ", input_type={type_}")?;
Expand Down
11 changes: 3 additions & 8 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::recursion_guard::ContainsRecursionState;
use crate::recursion_guard::RecursionError;
use crate::recursion_guard::RecursionGuard;
use crate::recursion_guard::RecursionState;
use crate::tools::safe_repr;
use crate::tools::truncate_safe_repr;
use crate::PydanticSerializationError;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -426,15 +426,10 @@ impl CollectWarnings {
.qualname()
.unwrap_or_else(|_| PyString::new_bound(value.py(), "<unknown python object>"));

let input_str = safe_repr(value);
let mut value_str = String::with_capacity(100);
value_str.push_str("with value `");
crate::errors::write_truncated_to_50_bytes(&mut value_str, input_str.to_cow())
.expect("Writing to a `String` failed");
value_str.push('`');
let value_str = truncate_safe_repr(value, None);

self.add_warning(format!(
"Expected `{field_type}` but got `{type_name}` {value_str} - serialized value may not be as expected"
"Expected `{field_type}` but got `{type_name}` with value `{value_str}` - serialized value may not be as expected"
));
}
}
Expand Down
20 changes: 19 additions & 1 deletion src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use smallvec::SmallVec;

use crate::serializers::extra::SerCheck;
use crate::serializers::DuckTypingSerMode;
use crate::tools::truncate_safe_repr;
use crate::PydanticSerializationUnexpectedValue;

use super::computed_fields::ComputedFields;
Expand Down Expand Up @@ -210,7 +211,24 @@ impl GeneralFieldsSerializer {
// Check for missing fields, we can't have extra fields here
&& self.required_fields > used_req_fields
{
Err(PydanticSerializationUnexpectedValue::new_err(None))
let required_fields = self.required_fields;
let type_name = match extra.model {
Some(model) => model
.get_type()
.qualname()
.ok()
.unwrap_or_else(|| PyString::new_bound(py, "<unknown python object>"))
.to_string(),
None => "<unknown python object>".to_string(),
};
let field_value = match extra.model {
Some(model) => truncate_safe_repr(model, Some(100)),
None => "<unknown python object>".to_string(),
};

Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected."
))))
} else {
Ok(output_dict)
}
Expand Down
3 changes: 1 addition & 2 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ impl TypeSerializer for ModelSerializer {
) -> PyResult<PyObject> {
let model = Some(value);
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();

let model_extra = Extra {
model,
field_name: None,
duck_typing_ser_mode,
..*extra
};
Expand Down Expand Up @@ -221,7 +221,6 @@ impl TypeSerializer for ModelSerializer {
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();
let model_extra = Extra {
model,
field_name: None,
duck_typing_ser_mode,
..*extra
};
Expand Down
12 changes: 3 additions & 9 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::errors::write_truncated_to_50_bytes;
use crate::lookup_key::LookupKey;
use crate::serializers::type_serializers::py_err_se_err;
use crate::tools::{safe_repr, SchemaDict};
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
Expand Down Expand Up @@ -446,15 +445,10 @@ impl TaggedUnionSerializer {
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let input_str = safe_repr(value);
let mut value_str = String::with_capacity(100);
value_str.push_str("with value `");
write_truncated_to_50_bytes(&mut value_str, input_str.to_cow()).expect("Writing to a `String` failed");
value_str.push('`');

let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization {value_str} - defaulting to left to right union serialization."
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
Expand Down
64 changes: 54 additions & 10 deletions src/tools.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use core::fmt;

use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -96,15 +96,6 @@ pub enum ReprOutput<'py> {
Fallback(String),
}

impl ReprOutput<'_> {
pub fn to_cow(&self) -> Cow<'_, str> {
match self {
ReprOutput::Python(s) => s.to_string_lossy(),
ReprOutput::Fallback(s) => s.into(),
}
}
}

impl std::fmt::Display for ReprOutput<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -124,6 +115,15 @@ pub fn safe_repr<'py>(v: &Bound<'py, PyAny>) -> ReprOutput<'py> {
}
}

pub fn truncate_safe_repr(v: &Bound<'_, PyAny>, max_len: Option<usize>) -> String {
let max_len = max_len.unwrap_or(50); // default to 100 bytes
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
let input_str = safe_repr(v);
let mut limited_str = String::with_capacity(max_len);
write_truncated_to_limited_bytes(&mut limited_str, &input_str.to_string(), max_len)
.expect("Writing to a `String` failed");
limited_str
}

pub fn extract_i64(v: &Bound<'_, PyAny>) -> Option<i64> {
#[cfg(PyPy)]
if !v.is_instance_of::<pyo3::types::PyInt>() {
Expand All @@ -146,3 +146,47 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach
pystring_fast_new(py, s, ascii_only)
}
}

// TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary
// with builtin methods once https://github.com/rust-lang/rust/issues/93743 is resolved
// These are just copy pasted from the current implementation
const fn is_utf8_char_boundary(value: u8) -> bool {
// This is bit magic equivalent to: b < 128 || b >= 192
(value as i8) >= -0x40
}

pub fn floor_char_boundary(value: &str, index: usize) -> usize {
if index >= value.len() {
value.len()
} else {
let lower_bound = index.saturating_sub(3);
let new_index = value.as_bytes()[lower_bound..=index]
.iter()
.rposition(|b| is_utf8_char_boundary(*b));

// SAFETY: we know that the character boundary will be within four bytes
unsafe { lower_bound + new_index.unwrap_unchecked() }
}
}

pub fn ceil_char_boundary(value: &str, index: usize) -> usize {
let upper_bound = Ord::min(index + 4, value.len());
value.as_bytes()[index..upper_bound]
.iter()
.position(|b| is_utf8_char_boundary(*b))
.map_or(upper_bound, |pos| pos + index)
}

pub fn write_truncated_to_limited_bytes<F: fmt::Write>(f: &mut F, val: &str, max_len: usize) -> std::fmt::Result {
if val.len() > max_len {
let mid_point = max_len.div_ceil(2);
write!(
f,
"{}...{}",
&val[0..floor_char_boundary(val, mid_point)],
&val[ceil_char_boundary(val, val.len() - (mid_point - 1))..]
)
} else {
write!(f, "{val}")
}
}
68 changes: 61 additions & 7 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
import pytest
from dirty_equals import IsJson

from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema
from pydantic_core import (
PydanticSerializationError,
PydanticSerializationUnexpectedValue,
SchemaSerializer,
SchemaValidator,
core_schema,
)

from ..conftest import plain_repr

Expand Down Expand Up @@ -1084,20 +1090,68 @@ class Model:


def test_no_warn_on_exclude() -> None:
warnings.simplefilter('error')
with warnings.catch_warnings():
warnings.simplefilter('error')

s = SchemaSerializer(
core_schema.model_schema(
BasicModel,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.int_schema()),
'b': core_schema.model_field(core_schema.int_schema()),
}
),
)
)

value = BasicModel(a=0, b=1)
assert s.to_python(value, exclude={'b'}) == {'a': 0}
assert s.to_python(value, mode='json', exclude={'b'}) == {'a': 0}


def test_warn_on_missing_field() -> None:
class AModel(BasicModel): ...

class BModel(BasicModel): ...

s = SchemaSerializer(
core_schema.model_schema(
BasicModel,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.int_schema()),
'b': core_schema.model_field(core_schema.int_schema()),
'root': core_schema.model_field(
core_schema.tagged_union_schema(
choices={
'a': core_schema.model_schema(
AModel,
core_schema.model_fields_schema(
{
'type': core_schema.model_field(core_schema.literal_schema(['a'])),
'a': core_schema.model_field(core_schema.int_schema()),
}
),
),
'b': core_schema.model_schema(
BModel,
core_schema.model_fields_schema(
{
'type': core_schema.model_field(core_schema.literal_schema(['b'])),
'b': core_schema.model_field(core_schema.int_schema()),
}
),
),
},
discriminator='type',
)
),
}
),
)
)

value = BasicModel(a=0, b=1)
assert s.to_python(value, exclude={'b'}) == {'a': 0}
assert s.to_python(value, mode='json', exclude={'b'}) == {'a': 0}
with pytest.raises(
PydanticSerializationUnexpectedValue, match='Expected 2 fields but got 1 for type `.*AModel` with value `.*`.+'
):
value = BasicModel(root=AModel(type='a'))
s.to_python(value)
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
Loading