Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,14 @@ def autogen_column_reflect(self, inspector, table, column_info):

"""

def autogen_table_reflect(self, inspector, table):
"""A hook that is called when a Table is reflected from the
database during the autogenerate process.

Dialects can elect to modify the information gathered here.

"""

def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
Expand Down
11 changes: 11 additions & 0 deletions alembic/ddl/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ def autogen_column_reflect(
):
column_info["default"] = "(%s)" % (column_info["default"],)

def autogen_table_reflect(self, inspector, table):
sql_text = sql.text(
"SELECT sql FROM sqlite_master WHERE name=:name AND type='table'"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should come from sqlalchemy, not be in alembic.

cc @zzzeek

)
res = inspector.bind.execute(sql_text, {"name": table.name}).scalar()
if res:
if re.search(r"\bSTRICT\b\s*;?\s*$", res, re.I):
table.kwargs["sqlite_strict"] = True
if re.search(r"\bWITHOUT ROWID\b", res, re.I):
table.kwargs["sqlite_with_rowid"] = False

def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
Expand Down
9 changes: 8 additions & 1 deletion alembic/operations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ def flush(self) -> None:
*self.reflect_args,
**self.reflect_kwargs,
)
from sqlalchemy import inspect as sqla_inspect

self.operations.impl.autogen_table_reflect(
sqla_inspect(self.operations.get_bind()),
existing_table,
)
reflected = True

batch_impl = ApplyBatchImpl(
Expand Down Expand Up @@ -642,7 +648,8 @@ def drop_column(
and kw["existing_type"].name # type:ignore[attr-defined]
):
self.named_constraints.pop(
kw["existing_type"].name, None # type:ignore[attr-defined]
kw["existing_type"].name,
None, # type:ignore[attr-defined]
)

def create_column_comment(self, column):
Expand Down
54 changes: 46 additions & 8 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def _assert_impl(
"INSERT INTO %(schema)s%(temp_name)s (%(colnames)s) "
"SELECT %(tname_colnames)s FROM %(schema)stname" % args,
"DROP TABLE %(schema)stname" % args,
"ALTER TABLE %(schema)s%(temp_name)s "
"RENAME TO %(schema)stname" % args,
"ALTER TABLE %(schema)s%(temp_name)s RENAME TO %(schema)stname"
% args,
]
)
if idx_stmt:
Expand Down Expand Up @@ -775,8 +775,9 @@ def test_add_fk(self):
new_table = self._assert_impl(
impl,
colnames=["id", "x", "y", "user_id"],
ddl_contains="CONSTRAINT fk1 FOREIGN KEY(user_id) "
'REFERENCES "user" (id)',
ddl_contains=(
"CONSTRAINT fk1 FOREIGN KEY(user_id) " 'REFERENCES "user" (id)'
),
)
eq_(
list(new_table.c.user_id.foreign_keys)[0]._get_colspec(), "user.id"
Expand Down Expand Up @@ -1589,8 +1590,7 @@ def test_fk_points_to_me_recreate(self):

@exclusions.only_on("sqlite")
@exclusions.fails(
"intentionally asserting that this "
"doesn't work w/ pragma foreign keys"
"intentionally asserting that this doesn't work w/ pragma foreign keys"
)
def test_fk_points_to_me_sqlite_refinteg(self):
with self._sqlite_referential_integrity():
Expand Down Expand Up @@ -1635,8 +1635,7 @@ def test_selfref_fk_recreate(self):

@exclusions.only_on("sqlite")
@exclusions.fails(
"intentionally asserting that this "
"doesn't work w/ pragma foreign keys"
"intentionally asserting that this doesn't work w/ pragma foreign keys"
)
def test_selfref_fk_sqlite_refinteg(self):
with self._sqlite_referential_integrity():
Expand Down Expand Up @@ -2298,6 +2297,45 @@ def test_create_drop_index(self):
insp = inspect(self.conn)
eq_(insp.get_indexes("foo"), [])

def test_sqlite_batch_strict(self):
"""test that STRICT is persisted in batch mode. See #1758"""
t = Table(
"t",
self.metadata,
Column("id", Integer, primary_key=True),
Column("data", Integer),
sqlite_strict=True,
)
with self.conn.begin():
t.create(self.conn)

with self.op.batch_alter_table("t", recreate="always") as batch_op:
batch_op.drop_column("data")

sql = self.conn.scalar(
text("SELECT sql FROM sqlite_master WHERE name='t'")
)
assert "STRICT" in sql

def test_sqlite_batch_without_rowid(self):
"""test that WITHOUT ROWID is persisted in batch mode. See #1758"""
t2 = Table(
"t2",
self.metadata,
Column("id", Integer, primary_key=True),
sqlite_with_rowid=False,
)
with self.conn.begin():
t2.create(self.conn)

with self.op.batch_alter_table("t2", recreate="always") as batch_op:
batch_op.add_column(Column("new_col", Integer))

sql = self.conn.scalar(
text("SELECT sql FROM sqlite_master WHERE name='t2'")
)
assert "WITHOUT ROWID" in sql


class BatchRoundTripMySQLTest(BatchRoundTripTest):
__only_on__ = "mysql", "mariadb"
Expand Down