Skip to content

Commit 34ba49d

Browse files
authored
pydantic constraints (#3)
* pydantic constraints * add hint * no cover
1 parent 0ef815a commit 34ba49d

File tree

3 files changed

+97
-24
lines changed

3 files changed

+97
-24
lines changed

src/dataclass_compat/_types.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -138,41 +138,42 @@ def _parse_annotated_hint(hint: Any) -> tuple[dict, dict]:
138138
return kwargs, constraints
139139

140140

141+
# At the moment, all of our constraint names match msgspec.Meta attributes
142+
# (we are a superset of msgspec.Meta)
143+
CONSTRAINT_NAMES = {f.name for f in dataclasses.fields(Constraints)}
144+
FIELD_NAMES = {f.name for f in dataclasses.fields(Field)}
145+
146+
141147
def _parse_annotatedtypes_meta(metadata: list[Any]) -> dict[str, Any]:
142148
"""Extract constraints from annotated_types metadata."""
143149
if TYPE_CHECKING:
144-
import annotated_types
150+
import annotated_types as at
145151
else:
146-
annotated_types = sys.modules.get("annotated_types")
147-
if annotated_types is None:
152+
at = sys.modules.get("annotated_types")
153+
if at is None:
148154
return {} # pragma: no cover
149155

150156
a_kwargs = {}
151157
for item in metadata:
152158
# annotated_types >= 0.3.0 is supported
153-
if isinstance(item, annotated_types.BaseMetadata):
154-
a_kwargs.update(dataclasses.asdict(item)) # type: ignore
155-
elif isinstance(item, annotated_types.GroupedMetadata):
156-
for i in item:
157-
a_kwargs.update(dataclasses.asdict(i)) # type: ignore
158-
# annotated types calls the value of a Predicate "func"
159-
if "func" in a_kwargs:
160-
a_kwargs["predicate"] = a_kwargs.pop("func")
161-
162-
# these were changed in v0.4.0
163-
if "min_inclusive" in a_kwargs: # pragma: no cover
164-
a_kwargs["min_length"] = a_kwargs.pop("min_inclusive")
165-
if "max_exclusive" in a_kwargs: # pragma: no cover
166-
a_kwargs["max_length"] = a_kwargs.pop("max_exclusive") - 1
159+
if isinstance(item, (at.BaseMetadata, at.GroupedMetadata)):
160+
try:
161+
values = dataclasses.asdict(item) # type: ignore
162+
except TypeError: # pragma: no cover
163+
continue
164+
a_kwargs.update({k: v for k, v in values.items() if k in CONSTRAINT_NAMES})
165+
# annotated types calls the value of a Predicate "func"
166+
if "func" in values:
167+
a_kwargs["predicate"] = values["func"]
168+
169+
# these were changed in v0.4.0
170+
if "min_inclusive" in values: # pragma: no cover
171+
a_kwargs["min_length"] = values["min_inclusive"]
172+
if "max_exclusive" in values: # pragma: no cover
173+
a_kwargs["max_length"] = values["max_exclusive"] - 1
167174
return a_kwargs
168175

169176

170-
# At the moment, all of our constraint names match msgspec.Meta attributes
171-
# (we are a superset of msgspec.Meta)
172-
CONSTRAINT_NAMES = {f.name for f in dataclasses.fields(Constraints)}
173-
FIELD_NAMES = {f.name for f in dataclasses.fields(Field)}
174-
175-
176177
def _parse_msgspec_meta(metadata: list[Any]) -> tuple[dict, dict]:
177178
"""Extract constraints from msgspec.Meta metadata."""
178179
if TYPE_CHECKING:

src/dataclass_compat/adapters/_pydantic.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
from __future__ import annotations
22

33
import dataclasses
4+
import re
45
import sys
56
from typing import TYPE_CHECKING, Any, Iterator, overload
67

7-
from dataclass_compat._types import DataclassParams, Field
8+
from dataclass_compat._types import (
9+
Constraints,
10+
DataclassParams,
11+
Field,
12+
_is_annotated_type,
13+
_parse_annotatedtypes_meta,
14+
)
815

916
if TYPE_CHECKING:
1017
import pydantic
18+
import pydantic.fields
1119
from typing_extensions import TypeGuard
1220

1321

@@ -89,9 +97,34 @@ def _fields_v1(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
8997
native_field=modelfield,
9098
description=modelfield.field_info.description, # type: ignore
9199
metadata=_extra_dict,
100+
constraints=_constraints_v1(modelfield),
92101
)
93102

94103

104+
def _constraints_v1(modelfield: Any) -> Constraints | None:
105+
kwargs = {}
106+
# check if the type is a pydantic constrained type
107+
for subt in modelfield.type_.__mro__:
108+
if (subt.__module__ or "").startswith("pydantic.types"):
109+
keys = (
110+
"gt",
111+
"ge",
112+
"lt",
113+
"le",
114+
"multiple_of",
115+
"max_digits",
116+
"decimal_places",
117+
"min_length",
118+
"max_length",
119+
)
120+
kwargs.update({key: getattr(modelfield.type_, key, None) for key in keys})
121+
if regex := getattr(modelfield.type_, "regex", None):
122+
if isinstance(regex, re.Pattern):
123+
regex = regex.pattern
124+
kwargs["pattern"] = regex
125+
return Constraints(**kwargs) if kwargs else None
126+
127+
95128
def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[Field]:
96129
from pydantic_core import PydanticUndefined
97130

@@ -100,6 +133,7 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
100133
else:
101134
_fields = obj.model_fields.items()
102135

136+
annotations = getattr(obj, "__annotations__", {})
103137
for name, finfo in _fields:
104138
factory = (
105139
finfo.default_factory if callable(finfo.default_factory) else Field.MISSING
@@ -110,6 +144,14 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
110144
else finfo.default
111145
)
112146
extra = finfo.json_schema_extra
147+
148+
annotated_type = annotations.get(name)
149+
if not _is_annotated_type(annotated_type):
150+
annotated_type = None
151+
152+
c = _parse_annotatedtypes_meta(finfo.metadata)
153+
constraints = Constraints(**c) if c else None
154+
113155
yield Field(
114156
name=name,
115157
type=finfo.annotation,
@@ -118,6 +160,8 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
118160
native_field=finfo,
119161
description=finfo.description,
120162
metadata=extra if isinstance(extra, dict) else {},
163+
annotated_type=annotated_type,
164+
constraints=constraints,
121165
)
122166

123167

tests/test_pydantic.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import annotated_types as at
2+
from dataclass_compat import fields
3+
from pydantic import BaseModel, Field, conint, constr
4+
from typing_extensions import Annotated
5+
6+
try:
7+
constr_ = constr(regex=r"^[a-z]+$")
8+
except TypeError:
9+
constr_ = constr(pattern=r"^[a-z]+$") # type: ignore
10+
11+
12+
def test_pydantic_constraints() -> None:
13+
class M(BaseModel):
14+
a: int = Field(default=50, ge=42, le=100)
15+
b: Annotated[int, Field(ge=42, le=100)] = 50
16+
c: Annotated[int, at.Ge(42), at.Le(100)] = 50
17+
d: conint(ge=42, le=100) = 50 # type: ignore
18+
e: constr_ = "abc" # type: ignore
19+
20+
for f in fields(M):
21+
assert f.constraints
22+
if f.name == "e":
23+
assert f.constraints.pattern == r"^[a-z]+$"
24+
else:
25+
assert f.constraints.ge == 42
26+
assert f.constraints.le == 100
27+
assert f.default == 50
28+
assert f.default == 50

0 commit comments

Comments
 (0)