Skip to content

Commit e2d4d1f

Browse files
Make most tests succeed
Only need to fix OPEN API things I think
1 parent 4a4161a commit e2d4d1f

File tree

7 files changed

+314
-155
lines changed

7 files changed

+314
-155
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pillow = "^9.3.0"
4646
cairosvg = "^2.5.2"
4747
mdx-include = "^1.4.1"
4848
coverage = {extras = ["toml"], version = ">=6.2,<8.0"}
49-
fastapi = "^0.68.1"
49+
fastapi = "^0.100.0"
5050
requests = "^2.26.0"
5151
ruff = "^0.1.2"
5252

sqlmodel/compat.py

Lines changed: 215 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import ipaddress
2+
import uuid
3+
from datetime import date, datetime, time, timedelta
4+
from decimal import Decimal
5+
from enum import Enum
6+
from pathlib import Path
17
from types import NoneType
28
from typing import (
39
TYPE_CHECKING,
@@ -6,26 +12,47 @@
612
Dict,
713
ForwardRef,
814
Optional,
15+
Sequence,
916
Type,
1017
TypeVar,
1118
Union,
19+
cast,
1220
get_args,
1321
get_origin,
1422
)
1523

1624
from pydantic import VERSION as PYDANTIC_VERSION
25+
from sqlalchemy import (
26+
Boolean,
27+
Column,
28+
Date,
29+
DateTime,
30+
Float,
31+
ForeignKey,
32+
Integer,
33+
Interval,
34+
Numeric,
35+
)
36+
from sqlalchemy import Enum as sa_Enum
37+
from sqlalchemy.sql.sqltypes import LargeBinary, Time
38+
39+
from .sql.sqltypes import GUID, AutoString
1740

1841
IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2
1942

2043

2144
if IS_PYDANTIC_V2:
2245
from pydantic import ConfigDict as PydanticModelConfig
46+
from pydantic._internal._fields import PydanticMetadata
47+
from pydantic._internal._model_construction import ModelMetaclass
2348
from pydantic_core import PydanticUndefined as PydanticUndefined # noqa
2449
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
2550
else:
2651
from pydantic import BaseConfig as PydanticModelConfig
27-
from pydantic.fields import ModelField # noqa
28-
from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON # noqa
52+
from pydantic.fields import SHAPE_SINGLETON, ModelField
53+
from pydantic.fields import Undefined as PydanticUndefined # noqa
54+
from pydantic.fields import UndefinedType as PydanticUndefinedType
55+
from pydantic.main import ModelMetaclass as ModelMetaclass
2956
from pydantic.typing import resolve_annotations
3057

3158
if TYPE_CHECKING:
@@ -37,11 +64,13 @@
3764
InstanceOrType = Union[T, Type[T]]
3865

3966
if IS_PYDANTIC_V2:
67+
4068
class SQLModelConfig(PydanticModelConfig, total=False):
4169
table: Optional[bool]
4270
registry: Optional[Any]
4371

4472
else:
73+
4574
class SQLModelConfig(PydanticModelConfig):
4675
table: Optional[bool] = None
4776
registry: Optional[Any] = None
@@ -78,14 +107,14 @@ def set_config_value(
78107

79108
def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
80109
if IS_PYDANTIC_V2:
81-
return model.model_fields # type: ignore
110+
return model.model_fields # type: ignore
82111
else:
83112
return model.__fields__ # type: ignore
84113

85114

86115
def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]:
87116
if IS_PYDANTIC_V2:
88-
return model.__pydantic_fields_set__ # type: ignore
117+
return model.__pydantic_fields_set__ # type: ignore
89118
else:
90119
return model.__fields_set__ # type: ignore
91120

@@ -115,21 +144,36 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
115144
)
116145

117146

118-
def is_table(class_dict: dict[str, Any]) -> bool:
147+
def class_dict_is_table(
148+
class_dict: dict[str, Any], class_kwargs: dict[str, Any]
149+
) -> bool:
119150
config: SQLModelConfig = {}
120151
if IS_PYDANTIC_V2:
121152
config = class_dict.get("model_config", {})
122153
else:
123154
config = class_dict.get("__config__", {})
124155
config_table = config.get("table", PydanticUndefined)
125156
if config_table is not PydanticUndefined:
126-
return config_table # type: ignore
127-
kw_table = class_dict.get("table", PydanticUndefined)
157+
return config_table # type: ignore
158+
kw_table = class_kwargs.get("table", PydanticUndefined)
128159
if kw_table is not PydanticUndefined:
129-
return kw_table # type: ignore
160+
return kw_table # type: ignore
130161
return False
131162

132163

164+
def cls_is_table(cls: Type) -> bool:
165+
if IS_PYDANTIC_V2:
166+
config = getattr(cls, "model_config", None)
167+
if not config:
168+
return False
169+
return config.get("table", False)
170+
else:
171+
config = getattr(cls, "__config__", None)
172+
if not config:
173+
return False
174+
return getattr(config, "table", False)
175+
176+
133177
def get_relationship_to(
134178
name: str,
135179
rel_info: "RelationshipInfo",
@@ -186,17 +230,15 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any])
186230
value.default in (PydanticUndefined, Ellipsis)
187231
) and value.default_factory is None:
188232
# So we can check for nullable
189-
value.original_default = value.default
190233
value.default = None
191234

192235

193-
def is_field_noneable(field: "FieldInfo") -> bool:
236+
def _is_field_noneable(field: "FieldInfo") -> bool:
194237
if IS_PYDANTIC_V2:
195238
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
196-
return field.nullable # type: ignore
239+
return field.nullable # type: ignore
197240
if not field.is_required():
198-
default = getattr(field, "original_default", field.default)
199-
if default is PydanticUndefined:
241+
if field.default is PydanticUndefined:
200242
return False
201243
if field.annotation is None or field.annotation is NoneType:
202244
return True
@@ -212,4 +254,163 @@ def is_field_noneable(field: "FieldInfo") -> bool:
212254
return field.allow_none and (
213255
field.shape != SHAPE_SINGLETON or not field.sub_fields
214256
)
215-
return False
257+
return field.allow_none
258+
259+
260+
def get_sqlalchemy_type(field: Any) -> Any:
261+
if IS_PYDANTIC_V2:
262+
field_info = field
263+
else:
264+
field_info = field.field_info
265+
sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009
266+
if sa_type is not PydanticUndefined:
267+
return sa_type
268+
269+
type_ = get_type_from_field(field)
270+
metadata = get_field_metadata(field)
271+
272+
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
273+
if issubclass(type_, Enum):
274+
return sa_Enum(type_)
275+
if issubclass(type_, str):
276+
max_length = getattr(metadata, "max_length", None)
277+
if max_length:
278+
return AutoString(length=max_length)
279+
return AutoString
280+
if issubclass(type_, float):
281+
return Float
282+
if issubclass(type_, bool):
283+
return Boolean
284+
if issubclass(type_, int):
285+
return Integer
286+
if issubclass(type_, datetime):
287+
return DateTime
288+
if issubclass(type_, date):
289+
return Date
290+
if issubclass(type_, timedelta):
291+
return Interval
292+
if issubclass(type_, time):
293+
return Time
294+
if issubclass(type_, bytes):
295+
return LargeBinary
296+
if issubclass(type_, Decimal):
297+
return Numeric(
298+
precision=getattr(metadata, "max_digits", None),
299+
scale=getattr(metadata, "decimal_places", None),
300+
)
301+
if issubclass(type_, ipaddress.IPv4Address):
302+
return AutoString
303+
if issubclass(type_, ipaddress.IPv4Network):
304+
return AutoString
305+
if issubclass(type_, ipaddress.IPv6Address):
306+
return AutoString
307+
if issubclass(type_, ipaddress.IPv6Network):
308+
return AutoString
309+
if issubclass(type_, Path):
310+
return AutoString
311+
if issubclass(type_, uuid.UUID):
312+
return GUID
313+
raise ValueError(f"{type_} has no matching SQLAlchemy type")
314+
315+
316+
def get_type_from_field(field: Any) -> type:
317+
if IS_PYDANTIC_V2:
318+
type_: type | None = field.annotation
319+
# Resolve Optional fields
320+
if type_ is None:
321+
raise ValueError("Missing field type")
322+
origin = get_origin(type_)
323+
if origin is None:
324+
return type_
325+
if origin is Union:
326+
bases = get_args(type_)
327+
if len(bases) > 2:
328+
raise ValueError(
329+
"Cannot have a (non-optional) union as a SQL alchemy field"
330+
)
331+
# Non optional unions are not allowed
332+
if bases[0] is not NoneType and bases[1] is not NoneType:
333+
raise ValueError(
334+
"Cannot have a (non-optional) union as a SQL alchemy field"
335+
)
336+
# Optional unions are allowed
337+
return bases[0] if bases[0] is not NoneType else bases[1]
338+
return origin
339+
else:
340+
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
341+
return field.type_
342+
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
343+
344+
345+
class FakeMetadata:
346+
max_length: Optional[int] = None
347+
max_digits: Optional[int] = None
348+
decimal_places: Optional[int] = None
349+
350+
351+
def get_field_metadata(field: Any) -> Any:
352+
if IS_PYDANTIC_V2:
353+
for meta in field.metadata:
354+
if isinstance(meta, PydanticMetadata):
355+
return meta
356+
return FakeMetadata()
357+
else:
358+
metadata = FakeMetadata()
359+
metadata.max_length = field.field_info.max_length
360+
metadata.max_digits = getattr(field.type_, "max_digits", None)
361+
metadata.decimal_places = getattr(field.type_, "decimal_places", None)
362+
return metadata
363+
364+
365+
def get_column_from_field(field: Any) -> Column: # type: ignore
366+
if IS_PYDANTIC_V2:
367+
field_info = field
368+
else:
369+
field_info = field.field_info
370+
sa_column = getattr(field_info, "sa_column", PydanticUndefined)
371+
if isinstance(sa_column, Column):
372+
return sa_column
373+
sa_type = get_sqlalchemy_type(field)
374+
primary_key = getattr(field_info, "primary_key", PydanticUndefined)
375+
if primary_key is PydanticUndefined:
376+
primary_key = False
377+
index = getattr(field_info, "index", PydanticUndefined)
378+
if index is PydanticUndefined:
379+
index = False
380+
nullable = not primary_key and _is_field_noneable(field)
381+
# Override derived nullability if the nullable property is set explicitly
382+
# on the field
383+
field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009
384+
if field_nullable is not PydanticUndefined:
385+
assert not isinstance(field_nullable, PydanticUndefinedType)
386+
nullable = field_nullable
387+
args = []
388+
foreign_key = getattr(field_info, "foreign_key", PydanticUndefined)
389+
if foreign_key is PydanticUndefined:
390+
foreign_key = None
391+
unique = getattr(field_info, "unique", PydanticUndefined)
392+
if unique is PydanticUndefined:
393+
unique = False
394+
if foreign_key:
395+
assert isinstance(foreign_key, str)
396+
args.append(ForeignKey(foreign_key))
397+
kwargs = {
398+
"primary_key": primary_key,
399+
"nullable": nullable,
400+
"index": index,
401+
"unique": unique,
402+
}
403+
sa_default = PydanticUndefined
404+
if field_info.default_factory:
405+
sa_default = field_info.default_factory
406+
elif field_info.default is not PydanticUndefined:
407+
sa_default = field_info.default
408+
if sa_default is not PydanticUndefined:
409+
kwargs["default"] = sa_default
410+
sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined)
411+
if sa_column_args is not PydanticUndefined:
412+
args.extend(list(cast(Sequence[Any], sa_column_args)))
413+
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined)
414+
if sa_column_kwargs is not PydanticUndefined:
415+
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
416+
return Column(sa_type, *args, **kwargs) # type: ignore

0 commit comments

Comments
 (0)