Skip to content

Commit a0d1fba

Browse files
Improve SQLAlchemy type imports
1 parent e64087a commit a0d1fba

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

cardinal_pythonlib/sqlalchemy/schema.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,25 @@
6161
Index,
6262
Table,
6363
)
64-
from sqlalchemy.sql import sqltypes, text
6564
from sqlalchemy.sql.ddl import DDLElement
65+
from sqlalchemy.sql.expression import text
6666
from sqlalchemy.sql.sqltypes import (
6767
BigInteger,
6868
Boolean,
6969
Date,
7070
DateTime,
7171
Double,
72+
Enum,
7273
Float,
7374
Integer,
75+
LargeBinary,
7476
Numeric,
7577
SmallInteger,
78+
String,
7679
Text,
7780
TypeEngine,
81+
Unicode,
82+
UnicodeText,
7883
)
7984
from sqlalchemy.sql.visitors import Visitable
8085

@@ -95,6 +100,14 @@
95100
# Constants
96101
# =============================================================================
97102

103+
# To avoid importing _Binary directly:
104+
if len(LargeBinary.__bases__) != 1:
105+
raise NotImplementedError(
106+
"Unexpectedly, SQLAlchemy's LargeBinary class has more than one base "
107+
"class"
108+
)
109+
BinaryBaseClass = LargeBinary.__bases__[0]
110+
98111
VisitableType = Type[Visitable] # for SQLAlchemy 2.0
99112

100113
MIN_TEXT_LENGTH_FOR_FREETEXT_INDEX = 1000
@@ -1100,29 +1113,29 @@ def convert_sqla_type_for_dialect(
11001113
# -------------------------------------------------------------------------
11011114
# Text
11021115
# -------------------------------------------------------------------------
1103-
if isinstance(coltype, sqltypes.Enum):
1104-
return sqltypes.String(length=coltype.length)
1105-
if isinstance(coltype, sqltypes.UnicodeText):
1116+
if isinstance(coltype, Enum):
1117+
return String(length=coltype.length)
1118+
if isinstance(coltype, UnicodeText):
11061119
# Unbounded Unicode text.
11071120
# Includes derived classes such as mssql.base.NTEXT.
1108-
return sqltypes.UnicodeText()
1109-
if isinstance(coltype, sqltypes.Text):
1121+
return UnicodeText()
1122+
if isinstance(coltype, Text):
11101123
# Unbounded text, more generally. (UnicodeText inherits from Text.)
11111124
# Includes sqltypes.TEXT.
1112-
return sqltypes.Text()
1125+
return Text()
11131126
# Everything inheriting from String has a length property, but can be None.
11141127
# There are types that can be unlimited in SQL Server, e.g. VARCHAR(MAX)
11151128
# and NVARCHAR(MAX), that MySQL needs a length for. (Failure to convert
11161129
# gives e.g.: 'NVARCHAR requires a length on dialect mysql'.)
1117-
if isinstance(coltype, sqltypes.Unicode):
1130+
if isinstance(coltype, Unicode):
11181131
# Includes NVARCHAR(MAX) in SQL -> NVARCHAR() in SQLAlchemy.
11191132
if (coltype.length is None and to_mysql) or expand_for_scrubbing:
1120-
return sqltypes.UnicodeText()
1133+
return UnicodeText()
11211134
# The most general case; will pick up any other string types.
1122-
if isinstance(coltype, sqltypes.String):
1135+
if isinstance(coltype, String):
11231136
# Includes VARCHAR(MAX) in SQL -> VARCHAR() in SQLAlchemy
11241137
if (coltype.length is None and to_mysql) or expand_for_scrubbing:
1125-
return sqltypes.Text()
1138+
return Text()
11261139
if strip_collation:
11271140
return remove_collation(coltype)
11281141
return coltype
@@ -1168,28 +1181,26 @@ def is_sqlatype_binary(coltype: Union[TypeEngine, VisitableType]) -> bool:
11681181
Is the SQLAlchemy column type a binary type?
11691182
"""
11701183
# Several binary types inherit internally from _Binary, making that the
1171-
# easiest to check.
1184+
# easiest to check. We obtain BinaryBaseClass (= _Binary) as above.
11721185
coltype = coltype_as_typeengine(coltype)
11731186
# noinspection PyProtectedMember
1174-
return isinstance(coltype, sqltypes._Binary)
1187+
return isinstance(coltype, BinaryBaseClass)
11751188

11761189

11771190
def is_sqlatype_date(coltype: Union[TypeEngine, VisitableType]) -> bool:
11781191
"""
11791192
Is the SQLAlchemy column type a date type?
11801193
"""
11811194
coltype = coltype_as_typeengine(coltype)
1182-
return isinstance(coltype, sqltypes.DateTime) or isinstance(
1183-
coltype, sqltypes.Date
1184-
)
1195+
return isinstance(coltype, DateTime) or isinstance(coltype, Date)
11851196

11861197

11871198
def is_sqlatype_integer(coltype: Union[TypeEngine, VisitableType]) -> bool:
11881199
"""
11891200
Is the SQLAlchemy column type an integer type?
11901201
"""
11911202
coltype = coltype_as_typeengine(coltype)
1192-
return isinstance(coltype, sqltypes.Integer)
1203+
return isinstance(coltype, Integer)
11931204

11941205

11951206
def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool:
@@ -1200,15 +1211,15 @@ def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool:
12001211
Note that integers don't count as Numeric!
12011212
"""
12021213
coltype = coltype_as_typeengine(coltype)
1203-
return isinstance(coltype, sqltypes.Numeric) # includes Float, Decimal
1214+
return isinstance(coltype, Numeric) # includes Float, Decimal
12041215

12051216

12061217
def is_sqlatype_string(coltype: Union[TypeEngine, VisitableType]) -> bool:
12071218
"""
12081219
Is the SQLAlchemy column type a string type?
12091220
"""
12101221
coltype = coltype_as_typeengine(coltype)
1211-
return isinstance(coltype, sqltypes.String)
1222+
return isinstance(coltype, String)
12121223

12131224

12141225
def is_sqlatype_text_of_length_at_least(
@@ -1220,7 +1231,7 @@ def is_sqlatype_text_of_length_at_least(
12201231
length?
12211232
"""
12221233
coltype = coltype_as_typeengine(coltype)
1223-
if not isinstance(coltype, sqltypes.String):
1234+
if not isinstance(coltype, String):
12241235
return False # not a string/text type at all
12251236
if coltype.length is None:
12261237
return True # string of unlimited length
@@ -1260,9 +1271,9 @@ def does_sqlatype_require_index_len(
12601271
https://dev.mysql.com/doc/refman/5.7/en/create-index.html.)
12611272
"""
12621273
coltype = coltype_as_typeengine(coltype)
1263-
if isinstance(coltype, sqltypes.Text):
1274+
if isinstance(coltype, Text):
12641275
return True
1265-
if isinstance(coltype, sqltypes.LargeBinary):
1276+
if isinstance(coltype, LargeBinary):
12661277
return True
12671278
return False
12681279

0 commit comments

Comments
 (0)