1
+ import ipaddress
2
+ import uuid
3
+ from datetime import date , datetime , time , timedelta
4
+ from decimal import Decimal
5
+ from enum import Enum
6
+ from pathlib import Path
1
7
from types import NoneType
2
8
from typing import (
3
9
TYPE_CHECKING ,
6
12
Dict ,
7
13
ForwardRef ,
8
14
Optional ,
15
+ Sequence ,
9
16
Type ,
10
17
TypeVar ,
11
18
Union ,
19
+ cast ,
12
20
get_args ,
13
21
get_origin ,
14
22
)
15
23
16
24
from pydantic import VERSION as PYDANTIC_VERSION
25
+ from sqlalchemy import (
26
+ Boolean ,
27
+ Column ,
28
+ Date ,
29
+ DateTime ,
30
+ Float ,
31
+ ForeignKey ,
32
+ Integer ,
33
+ Interval ,
34
+ Numeric ,
35
+ )
36
+ from sqlalchemy import Enum as sa_Enum
37
+ from sqlalchemy .sql .sqltypes import LargeBinary , Time
38
+
39
+ from .sql .sqltypes import GUID , AutoString
17
40
18
41
IS_PYDANTIC_V2 = int (PYDANTIC_VERSION .split ("." )[0 ]) >= 2
19
42
20
43
21
44
if IS_PYDANTIC_V2 :
22
45
from pydantic import ConfigDict as PydanticModelConfig
46
+ from pydantic ._internal ._fields import PydanticMetadata
47
+ from pydantic ._internal ._model_construction import ModelMetaclass
23
48
from pydantic_core import PydanticUndefined as PydanticUndefined # noqa
24
49
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
25
50
else :
26
51
from pydantic import BaseConfig as PydanticModelConfig
27
- from pydantic .fields import ModelField # noqa
28
- from pydantic .fields import Undefined as PydanticUndefined , UndefinedType as PydanticUndefinedType , SHAPE_SINGLETON # noqa
52
+ from pydantic .fields import SHAPE_SINGLETON , ModelField
53
+ from pydantic .fields import Undefined as PydanticUndefined # noqa
54
+ from pydantic .fields import UndefinedType as PydanticUndefinedType
55
+ from pydantic .main import ModelMetaclass as ModelMetaclass
29
56
from pydantic .typing import resolve_annotations
30
57
31
58
if TYPE_CHECKING :
37
64
InstanceOrType = Union [T , Type [T ]]
38
65
39
66
if IS_PYDANTIC_V2 :
67
+
40
68
class SQLModelConfig (PydanticModelConfig , total = False ):
41
69
table : Optional [bool ]
42
70
registry : Optional [Any ]
43
71
44
72
else :
73
+
45
74
class SQLModelConfig (PydanticModelConfig ):
46
75
table : Optional [bool ] = None
47
76
registry : Optional [Any ] = None
@@ -78,14 +107,14 @@ def set_config_value(
78
107
79
108
def get_model_fields (model : InstanceOrType ["SQLModel" ]) -> Dict [str , "FieldInfo" ]:
80
109
if IS_PYDANTIC_V2 :
81
- return model .model_fields # type: ignore
110
+ return model .model_fields # type: ignore
82
111
else :
83
112
return model .__fields__ # type: ignore
84
113
85
114
86
115
def get_fields_set (model : InstanceOrType ["SQLModel" ]) -> set [str ]:
87
116
if IS_PYDANTIC_V2 :
88
- return model .__pydantic_fields_set__ # type: ignore
117
+ return model .__pydantic_fields_set__ # type: ignore
89
118
else :
90
119
return model .__fields_set__ # type: ignore
91
120
@@ -115,21 +144,36 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
115
144
)
116
145
117
146
118
- def is_table (class_dict : dict [str , Any ]) -> bool :
147
+ def class_dict_is_table (
148
+ class_dict : dict [str , Any ], class_kwargs : dict [str , Any ]
149
+ ) -> bool :
119
150
config : SQLModelConfig = {}
120
151
if IS_PYDANTIC_V2 :
121
152
config = class_dict .get ("model_config" , {})
122
153
else :
123
154
config = class_dict .get ("__config__" , {})
124
155
config_table = config .get ("table" , PydanticUndefined )
125
156
if config_table is not PydanticUndefined :
126
- return config_table # type: ignore
127
- kw_table = class_dict .get ("table" , PydanticUndefined )
157
+ return config_table # type: ignore
158
+ kw_table = class_kwargs .get ("table" , PydanticUndefined )
128
159
if kw_table is not PydanticUndefined :
129
- return kw_table # type: ignore
160
+ return kw_table # type: ignore
130
161
return False
131
162
132
163
164
+ def cls_is_table (cls : Type ) -> bool :
165
+ if IS_PYDANTIC_V2 :
166
+ config = getattr (cls , "model_config" , None )
167
+ if not config :
168
+ return False
169
+ return config .get ("table" , False )
170
+ else :
171
+ config = getattr (cls , "__config__" , None )
172
+ if not config :
173
+ return False
174
+ return getattr (config , "table" , False )
175
+
176
+
133
177
def get_relationship_to (
134
178
name : str ,
135
179
rel_info : "RelationshipInfo" ,
@@ -186,17 +230,15 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any])
186
230
value .default in (PydanticUndefined , Ellipsis )
187
231
) and value .default_factory is None :
188
232
# So we can check for nullable
189
- value .original_default = value .default
190
233
value .default = None
191
234
192
235
193
- def is_field_noneable (field : "FieldInfo" ) -> bool :
236
+ def _is_field_noneable (field : "FieldInfo" ) -> bool :
194
237
if IS_PYDANTIC_V2 :
195
238
if getattr (field , "nullable" , PydanticUndefined ) is not PydanticUndefined :
196
- return field .nullable # type: ignore
239
+ return field .nullable # type: ignore
197
240
if not field .is_required ():
198
- default = getattr (field , "original_default" , field .default )
199
- if default is PydanticUndefined :
241
+ if field .default is PydanticUndefined :
200
242
return False
201
243
if field .annotation is None or field .annotation is NoneType :
202
244
return True
@@ -212,4 +254,163 @@ def is_field_noneable(field: "FieldInfo") -> bool:
212
254
return field .allow_none and (
213
255
field .shape != SHAPE_SINGLETON or not field .sub_fields
214
256
)
215
- return False
257
+ return field .allow_none
258
+
259
+
260
+ def get_sqlalchemy_type (field : Any ) -> Any :
261
+ if IS_PYDANTIC_V2 :
262
+ field_info = field
263
+ else :
264
+ field_info = field .field_info
265
+ sa_type = getattr (field_info , "sa_type" , PydanticUndefined ) # noqa: B009
266
+ if sa_type is not PydanticUndefined :
267
+ return sa_type
268
+
269
+ type_ = get_type_from_field (field )
270
+ metadata = get_field_metadata (field )
271
+
272
+ # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
273
+ if issubclass (type_ , Enum ):
274
+ return sa_Enum (type_ )
275
+ if issubclass (type_ , str ):
276
+ max_length = getattr (metadata , "max_length" , None )
277
+ if max_length :
278
+ return AutoString (length = max_length )
279
+ return AutoString
280
+ if issubclass (type_ , float ):
281
+ return Float
282
+ if issubclass (type_ , bool ):
283
+ return Boolean
284
+ if issubclass (type_ , int ):
285
+ return Integer
286
+ if issubclass (type_ , datetime ):
287
+ return DateTime
288
+ if issubclass (type_ , date ):
289
+ return Date
290
+ if issubclass (type_ , timedelta ):
291
+ return Interval
292
+ if issubclass (type_ , time ):
293
+ return Time
294
+ if issubclass (type_ , bytes ):
295
+ return LargeBinary
296
+ if issubclass (type_ , Decimal ):
297
+ return Numeric (
298
+ precision = getattr (metadata , "max_digits" , None ),
299
+ scale = getattr (metadata , "decimal_places" , None ),
300
+ )
301
+ if issubclass (type_ , ipaddress .IPv4Address ):
302
+ return AutoString
303
+ if issubclass (type_ , ipaddress .IPv4Network ):
304
+ return AutoString
305
+ if issubclass (type_ , ipaddress .IPv6Address ):
306
+ return AutoString
307
+ if issubclass (type_ , ipaddress .IPv6Network ):
308
+ return AutoString
309
+ if issubclass (type_ , Path ):
310
+ return AutoString
311
+ if issubclass (type_ , uuid .UUID ):
312
+ return GUID
313
+ raise ValueError (f"{ type_ } has no matching SQLAlchemy type" )
314
+
315
+
316
+ def get_type_from_field (field : Any ) -> type :
317
+ if IS_PYDANTIC_V2 :
318
+ type_ : type | None = field .annotation
319
+ # Resolve Optional fields
320
+ if type_ is None :
321
+ raise ValueError ("Missing field type" )
322
+ origin = get_origin (type_ )
323
+ if origin is None :
324
+ return type_
325
+ if origin is Union :
326
+ bases = get_args (type_ )
327
+ if len (bases ) > 2 :
328
+ raise ValueError (
329
+ "Cannot have a (non-optional) union as a SQL alchemy field"
330
+ )
331
+ # Non optional unions are not allowed
332
+ if bases [0 ] is not NoneType and bases [1 ] is not NoneType :
333
+ raise ValueError (
334
+ "Cannot have a (non-optional) union as a SQL alchemy field"
335
+ )
336
+ # Optional unions are allowed
337
+ return bases [0 ] if bases [0 ] is not NoneType else bases [1 ]
338
+ return origin
339
+ else :
340
+ if isinstance (field .type_ , type ) and field .shape == SHAPE_SINGLETON :
341
+ return field .type_
342
+ raise ValueError (f"The field { field .name } has no matching SQLAlchemy type" )
343
+
344
+
345
+ class FakeMetadata :
346
+ max_length : Optional [int ] = None
347
+ max_digits : Optional [int ] = None
348
+ decimal_places : Optional [int ] = None
349
+
350
+
351
+ def get_field_metadata (field : Any ) -> Any :
352
+ if IS_PYDANTIC_V2 :
353
+ for meta in field .metadata :
354
+ if isinstance (meta , PydanticMetadata ):
355
+ return meta
356
+ return FakeMetadata ()
357
+ else :
358
+ metadata = FakeMetadata ()
359
+ metadata .max_length = field .field_info .max_length
360
+ metadata .max_digits = getattr (field .type_ , "max_digits" , None )
361
+ metadata .decimal_places = getattr (field .type_ , "decimal_places" , None )
362
+ return metadata
363
+
364
+
365
+ def get_column_from_field (field : Any ) -> Column : # type: ignore
366
+ if IS_PYDANTIC_V2 :
367
+ field_info = field
368
+ else :
369
+ field_info = field .field_info
370
+ sa_column = getattr (field_info , "sa_column" , PydanticUndefined )
371
+ if isinstance (sa_column , Column ):
372
+ return sa_column
373
+ sa_type = get_sqlalchemy_type (field )
374
+ primary_key = getattr (field_info , "primary_key" , PydanticUndefined )
375
+ if primary_key is PydanticUndefined :
376
+ primary_key = False
377
+ index = getattr (field_info , "index" , PydanticUndefined )
378
+ if index is PydanticUndefined :
379
+ index = False
380
+ nullable = not primary_key and _is_field_noneable (field )
381
+ # Override derived nullability if the nullable property is set explicitly
382
+ # on the field
383
+ field_nullable = getattr (field_info , "nullable" , PydanticUndefined ) # noqa: B009
384
+ if field_nullable is not PydanticUndefined :
385
+ assert not isinstance (field_nullable , PydanticUndefinedType )
386
+ nullable = field_nullable
387
+ args = []
388
+ foreign_key = getattr (field_info , "foreign_key" , PydanticUndefined )
389
+ if foreign_key is PydanticUndefined :
390
+ foreign_key = None
391
+ unique = getattr (field_info , "unique" , PydanticUndefined )
392
+ if unique is PydanticUndefined :
393
+ unique = False
394
+ if foreign_key :
395
+ assert isinstance (foreign_key , str )
396
+ args .append (ForeignKey (foreign_key ))
397
+ kwargs = {
398
+ "primary_key" : primary_key ,
399
+ "nullable" : nullable ,
400
+ "index" : index ,
401
+ "unique" : unique ,
402
+ }
403
+ sa_default = PydanticUndefined
404
+ if field_info .default_factory :
405
+ sa_default = field_info .default_factory
406
+ elif field_info .default is not PydanticUndefined :
407
+ sa_default = field_info .default
408
+ if sa_default is not PydanticUndefined :
409
+ kwargs ["default" ] = sa_default
410
+ sa_column_args = getattr (field_info , "sa_column_args" , PydanticUndefined )
411
+ if sa_column_args is not PydanticUndefined :
412
+ args .extend (list (cast (Sequence [Any ], sa_column_args )))
413
+ sa_column_kwargs = getattr (field_info , "sa_column_kwargs" , PydanticUndefined )
414
+ if sa_column_kwargs is not PydanticUndefined :
415
+ kwargs .update (cast (Dict [Any , Any ], sa_column_kwargs ))
416
+ return Column (sa_type , * args , ** kwargs ) # type: ignore
0 commit comments