Skip to content

Commit

Permalink
fix: add support for Json type in Pydantic factory (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs authored Jul 30, 2023
1 parent 3e45f8f commit 2733497
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
50 changes: 45 additions & 5 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from polyfactory.utils.predicates import is_optional_union, is_safe_subclass, is_union

try:
from pydantic import VERSION, BaseModel
from pydantic import VERSION, BaseModel, Json
from pydantic.fields import FieldInfo
from pydantic_core import to_json
except ImportError as e:
raise MissingDependencyException("pydantic is not installed") from e

Expand All @@ -44,14 +45,34 @@
if TYPE_CHECKING:
from random import Random

from typing_extensions import TypeGuard
from typing_extensions import NotRequired, TypeGuard

T = TypeVar("T", bound=BaseModel)


class PydanticConstraints(Constraints):
"""Metadata regarding a Pydantic type constraints, if any"""

json: NotRequired[bool]


class PydanticFieldMeta(FieldMeta):
"""Field meta subclass capable of handling pydantic ModelFields"""

def __init__(
self,
*,
name: str,
annotation: type,
random: Random | None = None,
default: Any = ...,
children: list[FieldMeta] | None = None,
constraints: PydanticConstraints | None = None,
) -> None:
super().__init__(
name=name, annotation=annotation, random=random, default=default, children=children, constraints=constraints
)

@classmethod
def from_field_info(
cls,
Expand Down Expand Up @@ -102,13 +123,23 @@ def from_field_info(
for arg in get_args(annotation)
]

if metadata := [v for v in field_info.metadata if v is not None]:
constraints = cls.parse_constraints(metadata=metadata)
metadata, is_json = [], False
for m in field_info.metadata:
if not is_json and isinstance(m, Json): # type: ignore[misc]
is_json = True
elif m is not None:
metadata.append(m)

constraints = cls.parse_constraints(metadata=metadata) if metadata else {}
constraints = cast(PydanticConstraints, constraints)

if "url" in constraints:
# pydantic uses a sentinel value for url constraints
annotation = str

if is_json:
constraints["json"] = True

return PydanticFieldMeta.from_type(
annotation=annotation,
children=children,
Expand Down Expand Up @@ -245,7 +276,7 @@ def from_model_field( # pragma: no cover
annotation=annotation,
children=children or None,
default=default_value,
constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None,
constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
)


Expand Down Expand Up @@ -311,6 +342,15 @@ def get_model_fields(cls) -> list["FieldMeta"]:
]
return cls._fields_metadata

@classmethod
def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any:
constraints = cast(PydanticConstraints, field_meta.constraints)
if constraints.pop("json", None):
value = cls.get_field_value(field_meta)
return to_json(value)

return super().get_constrained_field_value(annotation, field_meta)

@classmethod
def build(cls, factory_use_construct: bool = False, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__
Expand Down
27 changes: 26 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys

import pytest
from pydantic import VERSION, BaseModel, Field
from pydantic import VERSION, BaseModel, Field, Json

from polyfactory.factories.pydantic_factory import ModelFactory

Expand Down Expand Up @@ -39,3 +39,28 @@ class CFactory(ModelFactory[C]):
assert isinstance(CFactory.build().c, list)
assert len(CFactory.build().c) > 0
assert isinstance(CFactory.build().c[0], (A, B))


@pytest.mark.skipif(VERSION.startswith("1"), reason="only for Pydantic v2")
def test_json_type() -> None:
class A(BaseModel):
a: Json[int]

class AFactory(ModelFactory[A]):
__model__ = A

assert isinstance(AFactory.build(), A)


@pytest.mark.skipif(VERSION.startswith("1"), reason="only for Pydantic v2")
def test_nested_json_type() -> None:
class A(BaseModel):
a: int

class B(BaseModel):
b: Json[A]

class BFactory(ModelFactory[B]):
__model__ = B

assert isinstance(BFactory.build(), B)

0 comments on commit 2733497

Please sign in to comment.