Skip to content

Commit c9292af

Browse files
authored
Fix issue with union parsing of enums (#6440)
1 parent 6bebd9c commit c9292af

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

pydantic/_internal/_std_types_schema.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
9595
update_json_schema(original_schema, updates)
9696
return json_schema
9797

98+
strict_python_schema = core_schema.is_instance_schema(enum_type)
99+
if use_enum_values:
100+
strict_python_schema = core_schema.chain_schema(
101+
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
102+
)
103+
98104
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
99105
if issubclass(enum_type, int):
100106
# this handles `IntEnum`, and also `Foobar(int, Enum)`
@@ -103,28 +109,26 @@ def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
103109
# Disallow float from JSON due to strict mode
104110
strict = core_schema.json_or_python_schema(
105111
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
106-
python_schema=core_schema.is_instance_schema(enum_type),
112+
python_schema=strict_python_schema,
107113
)
108114
elif issubclass(enum_type, str):
109115
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
110116
updates['type'] = 'string'
111117
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
112118
strict = core_schema.json_or_python_schema(
113119
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
114-
python_schema=core_schema.is_instance_schema(enum_type),
120+
python_schema=strict_python_schema,
115121
)
116122
elif issubclass(enum_type, float):
117123
updates['type'] = 'numeric'
118124
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
119125
strict = core_schema.json_or_python_schema(
120126
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
121-
python_schema=core_schema.is_instance_schema(enum_type),
127+
python_schema=strict_python_schema,
122128
)
123129
else:
124130
lax = to_enum_validator
125-
strict = core_schema.json_or_python_schema(
126-
json_schema=to_enum_validator, python_schema=core_schema.is_instance_schema(enum_type)
127-
)
131+
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
128132
return core_schema.lax_or_strict_schema(
129133
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
130134
)

tests/test_main.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Set,
1919
Type,
2020
TypeVar,
21+
Union,
2122
get_type_hints,
2223
)
2324
from uuid import UUID, uuid4
@@ -669,6 +670,32 @@ class Model(BaseModel):
669670
]
670671

671672

673+
def test_strict_enum_values():
674+
class MyEnum(Enum):
675+
val = 'val'
676+
677+
class Model(BaseModel):
678+
model_config = ConfigDict(use_enum_values=True)
679+
x: MyEnum
680+
681+
assert Model.model_validate({'x': MyEnum.val}, strict=True).x == 'val'
682+
683+
684+
def test_union_enum_values():
685+
class MyEnum(Enum):
686+
val = 'val'
687+
688+
class NormalModel(BaseModel):
689+
x: Union[MyEnum, int]
690+
691+
class UseEnumValuesModel(BaseModel):
692+
model_config = ConfigDict(use_enum_values=True)
693+
x: Union[MyEnum, int]
694+
695+
assert NormalModel(x=MyEnum.val).x != 'val'
696+
assert UseEnumValuesModel(x=MyEnum.val).x == 'val'
697+
698+
672699
def test_enum_raw():
673700
FooEnum = Enum('FooEnum', {'foo': 'foo', 'bar': 'bar'})
674701

0 commit comments

Comments
 (0)