Skip to content

Commit

Permalink
fix: Send correct field_meta to avoid over nesting (#527)
Browse files Browse the repository at this point in the history
Co-authored-by: isaac.jackson <isaac.jackson@penten.com>
Co-authored-by: guacs <126393040+guacs@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 8, 2024
1 parent 36a9cc1 commit b09bf64
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 3 deletions.
15 changes: 13 additions & 2 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
)

@classmethod
def get_field_value_coverage( # noqa: C901
def get_field_value_coverage( # noqa: C901,PLR0912
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
Expand Down Expand Up @@ -834,7 +834,18 @@ def get_field_value_coverage( # noqa: C901
)

elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
yield handle_collection_type_coverage(field_meta, origin, cls)
if not field_meta.children:
msg = "A subclass of Collection should always have children in it's field_meta"
raise ParameterException(msg)

# We actually want to use the parent in cases where the collection is the parent
# such as tuple so default to the parent if this annotation is not present in the children
child_meta = next(
(meta for meta in field_meta.children if meta.annotation == unwrapped_annotation),
field_meta,
)

yield handle_collection_type_coverage(child_meta, origin, cls)

elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
yield create_random_string(cls.__random__, min_length=1, max_length=10)
Expand Down
145 changes: 144 additions & 1 deletion tests/test_type_coverage_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@

from dataclasses import dataclass, make_dataclass
from datetime import date
from typing import Dict, FrozenSet, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Dict, FrozenSet, Iterable, List, Literal, Optional, Set, Tuple, Type, Union
from uuid import UUID

import pytest
from typing_extensions import TypedDict

from pydantic import BaseModel

from polyfactory.decorators import post_generated
from polyfactory.exceptions import ParameterException
from polyfactory.factories.dataclass_factory import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.factories.typed_dict_factory import TypedDictFactory
from polyfactory.utils.types import NoneType
from tests.test_pydantic_factory import IS_PYDANTIC_V1


def test_coverage_count() -> None:
Expand Down Expand Up @@ -227,3 +232,141 @@ class OptionalIntFactory(DataclassFactory[OptionalInt]):

assert isinstance(results[0].i, int)
assert results[1].i is None


def type_exists_at_path_any(objs: Iterable, path: List[Union[int, str]], target_type: Type) -> bool:
return any(type_exists_at_path(obj, path, target_type) for obj in objs)


def type_exists_at_path(obj: Any, path: List[Union[int, str]], target_type: Type) -> bool:
"""Determine if a type exists at a path through a given object
type_exists_at_path(obj, ["i", "*"], int)
returns true if 'obj' contains an iterable attribute called 'i' with an 'int' value
'*' is used to mean any of an iterable
type_exists_at_path(obj, ["i", 5], int)
returns true if 'obj' contains an iterable attribute called 'i' with an 'int' value at the index 5
Direct indexing is useful for checking tuples
:param obj: Object to search through
:param path: List of either indexes or attr keys to dereferrence
:param target_type: Type to match
:returns: A boolean, True if the type exists at the path, False otherwise
"""
# Handle fully dereferenced item and the end of path
if len(path) == 0:
return type(obj) == target_type

if path[0] == "*":
if not isinstance(obj, Iterable):
return False
for piece in obj:
if type_exists_at_path(piece, path[1:], target_type):
return True
return False

item, success = get_or_index(obj, path[0])
if not success:
return False

return type_exists_at_path(item, path[1:], target_type)


def get_or_index(obj: Any, idx: Union[int, str]) -> Tuple[Any, bool]:
if isinstance(idx, str):
if idx not in dir(obj):
return None, False

return getattr(obj, idx), True
if len(obj) < idx:
return None, False

return obj[idx], True


def test_coverage_optional_list() -> None:
@dataclass
class OptionalIntList:
i: Optional[List[int]]

class OptionalIntFactory(DataclassFactory[OptionalIntList]):
__model__ = OptionalIntList

results = list(OptionalIntFactory.coverage())

assert type_exists_at_path_any(results, ["i"], list)
assert type_exists_at_path_any(results, ["i", "*"], int)
assert type_exists_at_path_any(results, ["i"], NoneType)


def test_optional_lists() -> None:
class Model(BaseModel):
just_a_list: List[int]
optional_list: Optional[List[int]]
optional_nested_list: Optional[List[List[List[int]]]]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["just_a_list"], list)
assert type_exists_at_path_any(results, ["optional_list"], list)
assert type_exists_at_path_any(results, ["optional_list"], NoneType)
assert type_exists_at_path_any(results, ["optional_nested_list"], NoneType)
assert type_exists_at_path_any(results, ["optional_nested_list", "*"], list)
assert type_exists_at_path_any(results, ["optional_nested_list", "*", "*"], list)
assert type_exists_at_path_any(results, ["optional_nested_list", "*", "*", "*"], int)


def test_tuple_types() -> None:
class Model(BaseModel):
tii: Tuple[int, int]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["tii"], tuple)
assert type_exists_at_path_any(results, ["tii", 0], int)
assert type_exists_at_path_any(results, ["tii", 1], int)


def test_hetero_tuple_types() -> None:
class Model(BaseModel):
tis: Tuple[int, str]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["tis"], tuple)
assert type_exists_at_path_any(results, ["tis", 0], int)
assert type_exists_at_path_any(results, ["tis", 1], str)


def test_optional_list_uuid() -> None:
class Model(BaseModel):
maybe_uuids: Optional[List[UUID]]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["maybe_uuids"], list)
assert type_exists_at_path_any(results, ["maybe_uuids", "*"], UUID)
assert type_exists_at_path_any(results, ["maybe_uuids"], NoneType)


def test_optional_set_uuid() -> None:
class Model(BaseModel):
maybe_uuids: Optional[Set[UUID]]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["maybe_uuids"], set)
assert type_exists_at_path_any(results, ["maybe_uuids", "*"], UUID)
assert type_exists_at_path_any(results, ["maybe_uuids"], NoneType)


@pytest.mark.skipif(
IS_PYDANTIC_V1,
reason="This should be possible but more work needs to be done",
)
def test_optional_mixed_collecions() -> None:
class Model(BaseModel):
maybe_uuids: Optional[Union[Set[UUID], List[UUID]]]

results = list(ModelFactory.create_factory(Model).coverage())
assert type_exists_at_path_any(results, ["maybe_uuids"], set)
assert type_exists_at_path_any(results, ["maybe_uuids"], list)
assert type_exists_at_path_any(results, ["maybe_uuids", "*"], UUID)
assert type_exists_at_path_any(results, ["maybe_uuids"], NoneType)

0 comments on commit b09bf64

Please sign in to comment.