Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions src/dataclass_compat/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,41 +138,42 @@ def _parse_annotated_hint(hint: Any) -> tuple[dict, dict]:
return kwargs, constraints


# At the moment, all of our constraint names match msgspec.Meta attributes
# (we are a superset of msgspec.Meta)
CONSTRAINT_NAMES = {f.name for f in dataclasses.fields(Constraints)}
FIELD_NAMES = {f.name for f in dataclasses.fields(Field)}


def _parse_annotatedtypes_meta(metadata: list[Any]) -> dict[str, Any]:
"""Extract constraints from annotated_types metadata."""
if TYPE_CHECKING:
import annotated_types
import annotated_types as at
else:
annotated_types = sys.modules.get("annotated_types")
if annotated_types is None:
at = sys.modules.get("annotated_types")
if at is None:
return {} # pragma: no cover

a_kwargs = {}
for item in metadata:
# annotated_types >= 0.3.0 is supported
if isinstance(item, annotated_types.BaseMetadata):
a_kwargs.update(dataclasses.asdict(item)) # type: ignore
elif isinstance(item, annotated_types.GroupedMetadata):
for i in item:
a_kwargs.update(dataclasses.asdict(i)) # type: ignore
# annotated types calls the value of a Predicate "func"
if "func" in a_kwargs:
a_kwargs["predicate"] = a_kwargs.pop("func")

# these were changed in v0.4.0
if "min_inclusive" in a_kwargs: # pragma: no cover
a_kwargs["min_length"] = a_kwargs.pop("min_inclusive")
if "max_exclusive" in a_kwargs: # pragma: no cover
a_kwargs["max_length"] = a_kwargs.pop("max_exclusive") - 1
if isinstance(item, (at.BaseMetadata, at.GroupedMetadata)):
try:
values = dataclasses.asdict(item) # type: ignore
except TypeError: # pragma: no cover
continue
a_kwargs.update({k: v for k, v in values.items() if k in CONSTRAINT_NAMES})
# annotated types calls the value of a Predicate "func"
if "func" in values:
a_kwargs["predicate"] = values["func"]

# these were changed in v0.4.0
if "min_inclusive" in values: # pragma: no cover
a_kwargs["min_length"] = values["min_inclusive"]
if "max_exclusive" in values: # pragma: no cover
a_kwargs["max_length"] = values["max_exclusive"] - 1
return a_kwargs


# At the moment, all of our constraint names match msgspec.Meta attributes
# (we are a superset of msgspec.Meta)
CONSTRAINT_NAMES = {f.name for f in dataclasses.fields(Constraints)}
FIELD_NAMES = {f.name for f in dataclasses.fields(Field)}


def _parse_msgspec_meta(metadata: list[Any]) -> tuple[dict, dict]:
"""Extract constraints from msgspec.Meta metadata."""
if TYPE_CHECKING:
Expand Down
46 changes: 45 additions & 1 deletion src/dataclass_compat/adapters/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import annotations

import dataclasses
import re
import sys
from typing import TYPE_CHECKING, Any, Iterator, overload

from dataclass_compat._types import DataclassParams, Field
from dataclass_compat._types import (
Constraints,
DataclassParams,
Field,
_is_annotated_type,
_parse_annotatedtypes_meta,
)

if TYPE_CHECKING:
import pydantic
import pydantic.fields
from typing_extensions import TypeGuard


Expand Down Expand Up @@ -89,9 +97,34 @@ def _fields_v1(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
native_field=modelfield,
description=modelfield.field_info.description, # type: ignore
metadata=_extra_dict,
constraints=_constraints_v1(modelfield),
)


def _constraints_v1(modelfield: Any) -> Constraints | None:
kwargs = {}
# check if the type is a pydantic constrained type
for subt in modelfield.type_.__mro__:
if (subt.__module__ or "").startswith("pydantic.types"):
keys = (
"gt",
"ge",
"lt",
"le",
"multiple_of",
"max_digits",
"decimal_places",
"min_length",
"max_length",
)
kwargs.update({key: getattr(modelfield.type_, key, None) for key in keys})
if regex := getattr(modelfield.type_, "regex", None):
if isinstance(regex, re.Pattern):
regex = regex.pattern
kwargs["pattern"] = regex
return Constraints(**kwargs) if kwargs else None


def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[Field]:
from pydantic_core import PydanticUndefined

Expand All @@ -100,6 +133,7 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
else:
_fields = obj.model_fields.items()

annotations = getattr(obj, "__annotations__", {})
for name, finfo in _fields:
factory = (
finfo.default_factory if callable(finfo.default_factory) else Field.MISSING
Expand All @@ -110,6 +144,14 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
else finfo.default
)
extra = finfo.json_schema_extra

annotated_type = annotations.get(name)
if not _is_annotated_type(annotated_type):
annotated_type = None

c = _parse_annotatedtypes_meta(finfo.metadata)
constraints = Constraints(**c) if c else None

yield Field(
name=name,
type=finfo.annotation,
Expand All @@ -118,6 +160,8 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
native_field=finfo,
description=finfo.description,
metadata=extra if isinstance(extra, dict) else {},
annotated_type=annotated_type,
constraints=constraints,
)


Expand Down
28 changes: 28 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import annotated_types as at
from dataclass_compat import fields
from pydantic import BaseModel, Field, conint, constr
from typing_extensions import Annotated

try:
constr_ = constr(regex=r"^[a-z]+$")
except TypeError:
constr_ = constr(pattern=r"^[a-z]+$") # type: ignore


def test_pydantic_constraints() -> None:
class M(BaseModel):
a: int = Field(default=50, ge=42, le=100)
b: Annotated[int, Field(ge=42, le=100)] = 50
c: Annotated[int, at.Ge(42), at.Le(100)] = 50
d: conint(ge=42, le=100) = 50 # type: ignore
e: constr_ = "abc" # type: ignore

for f in fields(M):
assert f.constraints
if f.name == "e":
assert f.constraints.pattern == r"^[a-z]+$"
else:
assert f.constraints.ge == 42
assert f.constraints.le == 100
assert f.default == 50
assert f.default == 50