From 3cac77f828a64b991ec7abb1f54301594532b700 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 19 Jul 2024 13:39:52 +0200 Subject: [PATCH] fix: Adjust min/max items to valid lengths for Set[Enum] fields For `Set[Enum]` fields with a limited maximum length, adjust `min_items` and `max_items` to be within the range of valid lengths. --- .../constrained_collections.py | 6 ++++++ tests/test_complex_types.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/polyfactory/value_generators/constrained_collections.py b/polyfactory/value_generators/constrained_collections.py index 3953313b..626a2a08 100644 --- a/polyfactory/value_generators/constrained_collections.py +++ b/polyfactory/value_generators/constrained_collections.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import EnumType from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast from polyfactory.exceptions import ParameterException @@ -39,6 +40,11 @@ def handle_constrained_collection( min_items = abs(min_items if min_items is not None else (max_items or 0)) max_items = abs(max_items if max_items is not None else min_items + 1) + if isinstance(field_meta.annotation, EnumType): + max_items = len(field_meta.annotation) + if min_items > max_items: + min_items = max_items + if max_items < min_items: msg = "max_items must be larger or equal to min_items" raise ParameterException(msg) diff --git a/tests/test_complex_types.py b/tests/test_complex_types.py index dcc699ff..ac5232cf 100644 --- a/tests/test_complex_types.py +++ b/tests/test_complex_types.py @@ -155,6 +155,25 @@ class MyFactory(ModelFactory): assert result.animal_list +def test_complex_typing_with_enum_set() -> None: + class Animal(str, Enum): + DOG = "Dog" + CAT = "Cat" + MONKEY = "Monkey" + + class MyModel(BaseModel): + animal_list: Set[Animal] + + class MyFactory(ModelFactory): + __model__ = MyModel + __randomize_collection_length__ = True + __min_collection_length__ = len(Animal) + 1 + __min_collection_length__ = len(Animal) + 2 + + result = MyFactory.build() + assert len(result.animal_list) == len(Animal) + + def test_union_literal() -> None: class MyModel(BaseModel): x: Union[int, Literal["a", "b", "c"]]