Skip to content

Commit f46421f

Browse files
multimericpre-commit-ci[bot]tlambert03
authored
Fix for pydantic 1.X (#24)
* Use pydantic.v1 trick * style(pre-commit.ci): auto fixes [...] * Remove henious tuple calls * Fix merge conflict * style(pre-commit.ci): auto fixes [...] * Type checking * Add explicit __mro__ check * Fix type annotations * Reformat * Mypy labours * fix typing --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Talley Lambert <talley.lambert@gmail.com>
1 parent c874251 commit f46421f

File tree

2 files changed

+88
-24
lines changed

2 files changed

+88
-24
lines changed

src/fieldz/adapters/_pydantic.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
import re
55
import sys
6-
from typing import TYPE_CHECKING, Any, Iterator, overload
6+
from typing import TYPE_CHECKING, Any, Iterator, cast, overload
77

88
from fieldz._types import (
99
Constraints,
@@ -16,6 +16,7 @@
1616
if TYPE_CHECKING:
1717
import pydantic
1818
import pydantic.fields
19+
from pydantic.v1 import BaseModel as PydanticV1BaseModel
1920
from typing_extensions import TypeGuard
2021

2122

@@ -30,9 +31,12 @@ def is_pydantic_model(obj: object) -> TypeGuard[pydantic.BaseModel]: ...
3031
def is_pydantic_model(obj: Any) -> bool:
3132
"""Return True if obj is a pydantic.BaseModel subclass or instance."""
3233
pydantic = sys.modules.get("pydantic", None)
34+
pydantic_v1 = sys.modules.get("pydantic.v1", None)
3335
cls = obj if isinstance(obj, type) else type(obj)
3436
if pydantic is not None and issubclass(cls, pydantic.BaseModel):
3537
return True
38+
elif pydantic_v1 is not None and issubclass(cls, pydantic_v1.BaseModel):
39+
return True
3640
elif hasattr(cls, "__pydantic_model__") or hasattr(cls, "__pydantic_fields__"):
3741
return True
3842
return False
@@ -66,11 +70,14 @@ def replace(obj: pydantic.BaseModel, /, **changes: Any) -> Any:
6670
return obj.copy(update=changes)
6771

6872

69-
def _fields_v1(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[Field]:
70-
from pydantic.fields import Undefined # type: ignore
73+
def _fields_v1(obj: PydanticV1BaseModel | type[PydanticV1BaseModel]) -> Iterator[Field]:
74+
try:
75+
from pydantic.v1.fields import Undefined
76+
except ImportError:
77+
from pydantic.fields import Undefined # type: ignore
7178

72-
annotations = getattr(obj, "__annotations__", {})
73-
for name, modelfield in obj.__fields__.items(): # type: ignore
79+
annotations = {key: field.annotation for key, field in obj.__fields__.items()}
80+
for name, modelfield in obj.__fields__.items():
7481
factory = (
7582
modelfield.default_factory
7683
if callable(modelfield.default_factory)
@@ -83,7 +90,7 @@ def _fields_v1(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
8390
else modelfield.default
8491
)
8592
# backport from pydantic2
86-
_extra_dict = modelfield.field_info.extra.copy() # type: ignore
93+
_extra_dict = modelfield.field_info.extra.copy()
8794
if "json_schema_extra" in _extra_dict:
8895
_extra_dict.update(_extra_dict.pop("json_schema_extra"))
8996

@@ -93,14 +100,16 @@ def _fields_v1(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
93100
default=default,
94101
default_factory=(factory if callable(factory) else Field.MISSING),
95102
native_field=modelfield,
96-
description=modelfield.field_info.description, # type: ignore
103+
description=modelfield.field_info.description,
97104
metadata=_extra_dict,
98105
constraints=_constraints_v1(modelfield),
99106
)
100107

101108

102109
def _constraints_v1(modelfield: Any) -> Constraints | None:
103110
kwargs = {}
111+
if not hasattr(modelfield.type_, "__mro__"):
112+
return None
104113
# check if the type is a pydantic constrained type
105114
for subt in modelfield.type_.__mro__:
106115
if (subt.__module__ or "").startswith("pydantic.types"):
@@ -163,11 +172,18 @@ def _fields_v2(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> Iterator[F
163172
)
164173

165174

166-
def fields(obj: pydantic.BaseModel | type[pydantic.BaseModel]) -> tuple[Field, ...]:
175+
def fields(
176+
obj: pydantic.BaseModel
177+
| PydanticV1BaseModel
178+
| type[pydantic.BaseModel]
179+
| type[PydanticV1BaseModel],
180+
) -> tuple[Field, ...]:
167181
if hasattr(obj, "model_fields") or hasattr(obj, "__pydantic_fields__"):
182+
obj = cast("pydantic.BaseModel | type[pydantic.BaseModel]", obj)
168183
return tuple(_fields_v2(obj))
169184
if hasattr(obj, "__pydantic_model__"):
170185
obj = obj.__pydantic_model__ # v1 dataclass
186+
obj = cast("PydanticV1BaseModel | type[PydanticV1BaseModel]", obj)
171187
return tuple(_fields_v1(obj))
172188

173189

tests/test_fieldz.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import dataclasses
2-
from typing import Callable, List, NamedTuple, Optional, TypedDict
2+
from typing import Any, Callable, List, NamedTuple, Optional, TypedDict
33

44
import pytest
55

@@ -15,6 +15,7 @@ class Model:
1515
c: float = 0.0
1616
d: bool = False
1717
e: List[int] = dataclasses.field(default_factory=list)
18+
f: Any = ()
1819

1920
return Model
2021

@@ -26,6 +27,35 @@ class Model(NamedTuple):
2627
c: float = 0.0
2728
d: bool = False
2829
e: List[int] = [] # noqa
30+
f: Any = ()
31+
32+
return Model
33+
34+
35+
def _pydantic_v1_model_str() -> type:
36+
from pydantic.v1 import BaseModel, Field
37+
38+
class Model(BaseModel):
39+
a: "int" = 0
40+
b: "Optional[str]" = None
41+
c: "float" = 0.0
42+
d: "bool" = False
43+
e: "List[int]" = Field(default_factory=list)
44+
f: "Any" = ()
45+
46+
return Model
47+
48+
49+
def _pydantic_v1_model() -> type:
50+
from pydantic.v1 import BaseModel, Field
51+
52+
class Model(BaseModel):
53+
a: int = 0
54+
b: Optional[str] = None
55+
c: float = 0.0
56+
d: bool = False
57+
e: List[int] = Field(default_factory=list)
58+
f: Any = ()
2959

3060
return Model
3161

@@ -39,6 +69,7 @@ class Model(BaseModel):
3969
c: float = 0.0
4070
d: bool = False
4171
e: List[int] = Field(default_factory=list)
72+
f: Any = ()
4273

4374
return Model
4475

@@ -53,6 +84,7 @@ class Model:
5384
c: float = 0.0
5485
d: bool = False
5586
e: List[int] = dataclasses.field(default_factory=list)
87+
f: Any = ()
5688

5789
return Model
5890

@@ -67,6 +99,7 @@ class Model(SQLModel):
6799
c: float = 0.0
68100
d: bool = False
69101
e: List[int] = Field(default_factory=list)
102+
f: Any = ()
70103

71104
return Model
72105

@@ -81,6 +114,7 @@ class Model:
81114
c: float = 0.0
82115
d: bool = False
83116
e: List[int] = attr.field(default=attr.Factory(list))
117+
f: Any = ()
84118

85119
return Model
86120

@@ -94,6 +128,7 @@ class Model(msgspec.Struct):
94128
c: float = 0.0
95129
d: bool = False
96130
e: List[int] = msgspec.field(default_factory=list)
131+
f: Any = ()
97132

98133
return Model
99134

@@ -108,6 +143,7 @@ class Model:
108143
c: float = 0.0
109144
d: bool = False
110145
e: List[int] = [] # noqa
146+
f: Any = ()
111147

112148
return Model
113149

@@ -121,6 +157,7 @@ class Model(models.Model):
121157
c: float = models.FloatField(default=0.0)
122158
d: bool = models.BooleanField(default=False)
123159
e: List[int] = models.JSONField(default=list)
160+
f: Any = ()
124161

125162
return Model
126163

@@ -132,6 +169,8 @@ class Model(models.Model):
132169
_named_tuple,
133170
_dataclassy_model,
134171
_pydantic_model,
172+
_pydantic_v1_model,
173+
_pydantic_v1_model_str,
135174
_attrs_model,
136175
_msgspec_model,
137176
_sqlmodel,
@@ -142,30 +181,39 @@ class Model(models.Model):
142181
def test_adapters(builder: Callable) -> None:
143182
model = builder()
144183
obj = model()
145-
assert asdict(obj) == {"a": 0, "b": None, "c": 0.0, "d": False, "e": []}
146-
assert astuple(obj) == (0, None, 0.0, False, [])
184+
assert asdict(obj) == {"a": 0, "b": None, "c": 0.0, "d": False, "e": [], "f": ()}
185+
assert astuple(obj) == (0, None, 0.0, False, [], ())
147186
fields_ = fields(obj)
148-
assert [f.name for f in fields_] == ["a", "b", "c", "d", "e"]
149-
assert [f.type for f in fields_] == [int, Optional[str], float, bool, List[int]]
150-
assert [f.frozen for f in fields_] == [False] * 5
187+
assert [f.name for f in fields_] == ["a", "b", "c", "d", "e", "f"]
188+
assert [f.type for f in fields_] == [
189+
int,
190+
Optional[str],
191+
float,
192+
bool,
193+
List[int],
194+
Any,
195+
]
196+
assert [f.frozen for f in fields_] == [False] * 6
151197
if is_named_tuple(obj):
152-
assert [f.default for f in fields_] == [0, None, 0.0, False, []]
198+
assert [f.default for f in fields_] == [0, None, 0.0, False, [], ()]
153199
else:
154200
# namedtuples don't have default_factory
155-
assert [f.default for f in fields_] == [
156-
0,
157-
None,
158-
0.0,
159-
False,
160-
Field.MISSING,
161-
]
201+
assert [f.default for f in fields_] == [0, None, 0.0, False, Field.MISSING, ()]
162202
assert [f.default_factory for f in fields_] == [
163203
*[Field.MISSING] * 4,
164204
list,
205+
Field.MISSING,
165206
]
166207

167-
obj2 = replace(obj, a=1, b="b2", c=1.0, d=True, e=[1, 2, 3])
168-
assert asdict(obj2) == {"a": 1, "b": "b2", "c": 1.0, "d": True, "e": [1, 2, 3]}
208+
obj2 = replace(obj, a=1, b="b2", c=1.0, d=True, e=[1, 2, 3], f={})
209+
assert asdict(obj2) == {
210+
"a": 1,
211+
"b": "b2",
212+
"c": 1.0,
213+
"d": True,
214+
"e": [1, 2, 3],
215+
"f": {},
216+
}
169217

170218
p = params(obj)
171219
assert p.eq is True

0 commit comments

Comments
 (0)