Skip to content

Commit

Permalink
Fix issue with union parsing of enums (#6440)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu authored Jul 5, 2023
1 parent 6bebd9c commit c9292af
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pydantic/_internal/_std_types_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
update_json_schema(original_schema, updates)
return json_schema

strict_python_schema = core_schema.is_instance_schema(enum_type)
if use_enum_values:
strict_python_schema = core_schema.chain_schema(
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
)

to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
if issubclass(enum_type, int):
# this handles `IntEnum`, and also `Foobar(int, Enum)`
Expand All @@ -103,28 +109,26 @@ def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# Disallow float from JSON due to strict mode
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
python_schema=core_schema.is_instance_schema(enum_type),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
updates['type'] = 'string'
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
python_schema=core_schema.is_instance_schema(enum_type),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, float):
updates['type'] = 'numeric'
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
python_schema=core_schema.is_instance_schema(enum_type),
python_schema=strict_python_schema,
)
else:
lax = to_enum_validator
strict = core_schema.json_or_python_schema(
json_schema=to_enum_validator, python_schema=core_schema.is_instance_schema(enum_type)
)
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
return core_schema.lax_or_strict_schema(
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Set,
Type,
TypeVar,
Union,
get_type_hints,
)
from uuid import UUID, uuid4
Expand Down Expand Up @@ -669,6 +670,32 @@ class Model(BaseModel):
]


def test_strict_enum_values():
class MyEnum(Enum):
val = 'val'

class Model(BaseModel):
model_config = ConfigDict(use_enum_values=True)
x: MyEnum

assert Model.model_validate({'x': MyEnum.val}, strict=True).x == 'val'


def test_union_enum_values():
class MyEnum(Enum):
val = 'val'

class NormalModel(BaseModel):
x: Union[MyEnum, int]

class UseEnumValuesModel(BaseModel):
model_config = ConfigDict(use_enum_values=True)
x: Union[MyEnum, int]

assert NormalModel(x=MyEnum.val).x != 'val'
assert UseEnumValuesModel(x=MyEnum.val).x == 'val'


def test_enum_raw():
FooEnum = Enum('FooEnum', {'foo': 'foo', 'bar': 'bar'})

Expand Down

0 comments on commit c9292af

Please sign in to comment.