Skip to content

Commit

Permalink
Support wider variety of enum validation cases (pydantic#1456)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Sep 20, 2024
1 parent e0b4c94 commit 8c1a0da
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,15 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
},
input,
));
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
state.floor_exactness(Exactness::Lax);
}

state.floor_exactness(Exactness::Lax);

if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
return Ok(v);
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
return Ok(res);
} else if let Some(ref missing) = self.missing {
state.floor_exactness(Exactness::Lax);
let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| {
ValError::new(
ErrorType::Enum {
Expand All @@ -146,6 +150,7 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
return Err(type_error.into());
}
}

Err(ValError::new(
ErrorType::Enum {
expected: self.expected_repr.clone(),
Expand Down
143 changes: 143 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,145 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE
assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


@pytest.mark.skipif(
sys.version_info >= (3, 13),
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
)
def test_enum_int_validation_should_succeed_for_custom_type():
class AnyWrapper:
def __init__(self, value):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyEnum(Enum):
VALUE = 999
SECOND_VALUE = 1000000
THIRD_VALUE = 'Py03'

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE


def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
class MyEnum(Enum):
VALUE = '1'

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
class MyEnum(Enum):
VALUE = 1

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v.validate_python(Decimal(1.1))


def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
class MyEnum(Enum):
VALUE = 1

class MyClass:
def __init__(self, value):
self.value = value

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(MyClass(1))


def support_custom_new_method() -> None:
"""Demonstrates support for custom new methods, as well as conceptually, multi-value enums without dependency on a 3rd party lib for testing."""

class Animal(Enum):
CAT = 'cat', 'meow'
DOG = 'dog', 'woof'

def __new__(cls, species: str, sound: str):
obj = object.__new__(cls)

obj._value_ = species
obj._all_values = (species, sound)

obj.species = species
obj.sound = sound

cls._value2member_map_[sound] = obj

return obj

v = SchemaValidator(core_schema.enum_schema(Animal, list(Animal.__members__.values())))
assert v.validate_python('cat') is Animal.CAT
assert v.validate_python('meow') is Animal.CAT
assert v.validate_python('dog') is Animal.DOG
assert v.validate_python('woof') is Animal.DOG
65 changes: 65 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
import sys
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1314,66 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


@pytest.mark.skipif(
sys.version_info >= (3, 13),
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
)
def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
# GIVEN
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2
enum_value_3 = 3

class IntWrappable:
def __init__(self, value: int):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

# WHEN
v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_3': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)

# THEN
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
m = v.validate_python(
{
'enum_field': Decimal(1),
'enum_field_2': Decimal(2),
'enum_field_3': IntWrappable(3),
}
)
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))

0 comments on commit 8c1a0da

Please sign in to comment.