Skip to content

Commit f3821e4

Browse files
Fix invalid bool() conversion of a DDL object, and some additional tests/tidying
1 parent 0b9b005 commit f3821e4

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

cardinal_pythonlib/sqlalchemy/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def execute_ddl(
357357
ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER), and pass that
358358
DDL object to this function.
359359
"""
360-
assert bool(sql) ^ bool(ddl) # one or the other.
360+
assert bool(sql) ^ (ddl is not None) # one or the other.
361361
if sql:
362362
ddl = DDL(sql)
363363
with engine.connect() as connection:

cardinal_pythonlib/sqlalchemy/tests/schema_tests.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838
from sqlalchemy.orm import declarative_base
3939
from sqlalchemy.schema import (
4040
Column,
41+
CreateTable,
4142
DDLElement,
4243
Index,
4344
MetaData,
4445
Sequence,
4546
Table,
4647
)
4748
from sqlalchemy.sql import table
49+
from sqlalchemy.sql.selectable import Select
4850
from sqlalchemy.sql.sqltypes import (
4951
BigInteger,
5052
Date,
@@ -65,6 +67,7 @@
6567
columns_equal,
6668
convert_sqla_type_for_dialect,
6769
does_sqlatype_require_index_len,
70+
execute_ddl,
6871
gen_columns_info,
6972
get_column_info,
7073
get_column_names,
@@ -287,8 +290,15 @@ def _view_doesnt_exist(ddl, target, connection, **kw):
287290
return not _view_exists(ddl, target, connection, **kw)
288291

289292

290-
def _view(name, metadata, selectable) -> Table:
291-
t = table(name)
293+
def _attach_view(
294+
tablename: str, metadata: MetaData, selectable: Select
295+
) -> None:
296+
"""
297+
Attaches a view to a table of the given name, such that the view (which is
298+
of "selectable") is created after the table is created, and dropped before
299+
the table is dropped, via listeners.
300+
"""
301+
t = table(tablename)
292302

293303
# noinspection PyProtectedMember
294304
t._columns._populate_separate_keys(
@@ -298,14 +308,15 @@ def _view(name, metadata, selectable) -> Table:
298308
event.listen(
299309
metadata,
300310
"after_create",
301-
CreateView(name, selectable).execute_if(callable_=_view_doesnt_exist),
311+
CreateView(tablename, selectable).execute_if(
312+
callable_=_view_doesnt_exist
313+
),
302314
)
303315
event.listen(
304316
metadata,
305317
"before_drop",
306-
DropView(name).execute_if(callable_=_view_exists),
318+
DropView(tablename).execute_if(callable_=_view_exists),
307319
)
308-
return t
309320

310321

311322
class MoreSchemaTests(unittest.TestCase):
@@ -322,17 +333,17 @@ def setUp(self) -> None:
322333
Column("name", String(50)),
323334
)
324335

325-
_view(
336+
_attach_view(
326337
"one",
327338
metadata,
328339
select(self.person.c.id.label("name")),
329340
)
330-
_view(
341+
_attach_view(
331342
"two",
332343
metadata,
333344
select(self.person.c.id.label("name")),
334345
)
335-
_view(
346+
_attach_view(
336347
"three",
337348
metadata,
338349
select(self.person.c.id.label("name")),
@@ -486,16 +497,16 @@ def setUp(self) -> None:
486497
self.engine = create_engine(
487498
SQLITE_MEMORY_URL, echo=self.echo, future=True
488499
)
489-
metadata = MetaData()
500+
self.metadata = MetaData()
490501
self.person = Table(
491502
"person",
492-
metadata,
503+
self.metadata,
493504
Column("id", Integer, primary_key=True, autoincrement=False),
494505
Column("name", String(50)),
495506
make_bigint_autoincrement_column("bigthing"),
496507
)
497508
with self.engine.begin() as conn:
498-
metadata.create_all(conn)
509+
self.metadata.create_all(conn)
499510

500511
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
501512
# get_single_int_autoincrement_colname (again)
@@ -527,8 +538,8 @@ def test_column_creation_ddl(self) -> None:
527538
"you", BigInteger, Sequence("dummy_name", start=1, increment=1)
528539
)
529540

530-
metadata = MetaData()
531-
t = Table("mytable", metadata)
541+
self.metadata = MetaData()
542+
t = Table("mytable", self.metadata)
532543
t.append_column(col1)
533544
t.append_column(col2)
534545
t.append_column(col3)
@@ -605,6 +616,21 @@ def test_column_lists_equal(self) -> None:
605616
self.assertFalse(column_lists_equal([a, b], [b, a]))
606617
self.assertFalse(column_lists_equal([a, b], [a]))
607618

619+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
620+
# execute_ddl
621+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
622+
def test_execute_ddl(self) -> None:
623+
sql = "CREATE TABLE x (a INT)"
624+
execute_ddl(self.engine, sql=sql)
625+
626+
ddl = CreateTable(Table("y", self.metadata, Column("z", Integer)))
627+
execute_ddl(self.engine, ddl=ddl)
628+
629+
with self.assertRaises(AssertionError):
630+
execute_ddl(self.engine, sql=sql, ddl=ddl) # both
631+
with self.assertRaises(AssertionError):
632+
execute_ddl(self.engine) # neither
633+
608634

609635
class SchemaAbstractTests(unittest.TestCase):
610636
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)