Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions alembic/autogenerate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,18 @@ def _make_foreign_key(
return const


def _make_check_constraint(
params: Dict[str, Any], conn_table: Table
) -> sa_schema.CheckConstraint:
const = sa_schema.CheckConstraint(
text(params["sqltext"]),
name=params["name"],
_autoattach=False,
)
const._set_parent_with_dispatch(conn_table)
return const


@contextlib.contextmanager
def _compare_columns(
schema: Optional[str],
Expand Down Expand Up @@ -479,6 +491,11 @@ def get_foreign_keys(self, *args, **kw):
self.inspector.get_foreign_keys(*args, **kw)
)

def get_check_constraints(self, *args, **kw):
return self._apply_reflectinfo_conv(
self.inspector.get_check_constraints(*args, **kw)
)

def reflect_table(self, table, *, include_columns):
self.inspector.reflect_table(table, include_columns=include_columns)

Expand Down Expand Up @@ -1368,3 +1385,152 @@ def _compare_table_comment(
schema=schema,
)
)


@comparators.dispatch_for("table")
def _compare_check_constraints(
autogen_context: AutogenContext,
modify_table_ops: ModifyTableOps,
schema: Optional[str],
tname: Union[quoted_name, str],
conn_table: Optional[Table],
metadata_table: Optional[Table],
) -> None:
if not autogen_context.opts.get("compare_check_constraints", False):
return

if conn_table is None or metadata_table is None:
return

inspector = autogen_context.inspector

metadata_check_constraints = {
ck
for ck in metadata_table.constraints
if isinstance(ck, sa_schema.CheckConstraint)
and not getattr(ck, "_type_bound", False)
}

for ck in metadata_check_constraints:
if ck.name is None:
raise ValueError(
f"Unnamed check constraint on table {tname!r} cannot be "
f"compared. When compare_check_constraints is enabled, all "
f"check constraints must have explicit names. "
f"Constraint SQL: {ck.sqltext}"
)

try:
conn_cks_list = _InspectorConv(inspector).get_check_constraints(
tname, schema=schema
)
except NotImplementedError:
return

def _is_type_bound_constraint(ck_dict: Dict[str, Any]) -> bool:
name = ck_dict.get("name")
sqltext = ck_dict.get("sqltext", "")
if not name:
return False
if "::text = ANY ((ARRAY[" in sqltext:
return True
if "::text = ANY (ARRAY[" in sqltext:
return True
return False

conn_cks_list = [
ck
for ck in conn_cks_list
if autogen_context.run_name_filters(
ck["name"],
"check_constraint",
{"table_name": tname, "schema_name": schema},
)
and not _is_type_bound_constraint(ck)
]

conn_check_constraints = {
_make_check_constraint(ck_def, conn_table) for ck_def in conn_cks_list
}

impl = autogen_context.migration_context.impl

metadata_cks_sig = {
impl._create_metadata_constraint_sig(ck)
for ck in metadata_check_constraints
}

conn_cks_sig = {
impl._create_reflected_constraint_sig(ck)
for ck in conn_check_constraints
}

metadata_cks_by_name: Dict[str, Any] = {
str(c.name): c
for c in metadata_cks_sig
if sqla_compat.constraint_name_string(c.name)
}
conn_cks_by_name = {
str(c.name): c
for c in conn_cks_sig
if sqla_compat.constraint_name_string(c.name)
}

def _add_ck(obj, compare_to):
if autogen_context.run_object_filters(
obj.const, obj.name, "check_constraint", False, compare_to
):
modify_table_ops.ops.append(
ops.CreateCheckConstraintOp.from_constraint(obj.const)
)
log.info(
"Detected added check constraint %r on table %r",
obj.name,
tname,
)

def _remove_ck(obj, compare_to):
if autogen_context.run_object_filters(
obj.const, obj.name, "check_constraint", True, compare_to
):
modify_table_ops.ops.append(
ops.DropConstraintOp.from_constraint(obj.const)
)
log.info(
"Detected removed check constraint %r on table %r",
obj.name,
tname,
)

for removed_name in sorted(
set(conn_cks_by_name).difference(metadata_cks_by_name)
):
const = conn_cks_by_name[removed_name]
compare_to = (
metadata_cks_by_name[removed_name].const
if removed_name in metadata_cks_by_name
else None
)
_remove_ck(const, compare_to)

for added_name in sorted(
set(metadata_cks_by_name).difference(conn_cks_by_name)
):
const = metadata_cks_by_name[added_name]
compare_to = (
conn_cks_by_name[added_name].const
if added_name in conn_cks_by_name
else None
)
_add_ck(const, compare_to)

for existing_name in sorted(
set(metadata_cks_by_name).intersection(conn_cks_by_name)
):
metadata_ck = metadata_cks_by_name[existing_name]
conn_ck = conn_cks_by_name[existing_name]

comparison = metadata_ck.compare_to_reflected(conn_ck)
if comparison.is_different:
_remove_ck(conn_ck, metadata_ck.const)
_add_ck(metadata_ck, conn_ck.const)
37 changes: 35 additions & 2 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,41 @@ def _add_pk_constraint(constraint, autogen_context):


@renderers.dispatch_for(ops.CreateCheckConstraintOp)
def _add_check_constraint(constraint, autogen_context):
raise NotImplementedError()
def _add_check_constraint(
autogen_context: AutogenContext, op: ops.CreateCheckConstraintOp
) -> str:
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
args.append(repr(_ident(op.table_name)))

if isinstance(op.condition, str):
condition_text = op.condition
elif hasattr(op.condition, "text"):
condition_text = op.condition.text
else:
from sqlalchemy.dialects import postgresql

condition_text = op.condition.compile(
dialect=postgresql.dialect(),
compile_kwargs={"literal_binds": True},
).string
args.append(repr(condition_text))

if op.schema and not autogen_context._has_batch:
args.append("schema=%r" % op.schema)

constraint = op.to_constraint()
dialect_kwargs = _render_dialect_kwargs_items(
autogen_context, constraint.dialect_kwargs
)

return "%(prefix)screate_check_constraint(%(args)s%(dialect_kwargs)s)" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
"dialect_kwargs": (
", " + ", ".join(dialect_kwargs) if dialect_kwargs else ""
),
}


@renderers.dispatch_for(ops.DropConstraintOp)
Expand Down
27 changes: 23 additions & 4 deletions alembic/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def configure(
"index",
"unique_constraint",
"foreign_key_constraint",
"check_constraint",
],
MutableMapping[
Literal[
Expand All @@ -138,6 +139,7 @@ def configure(
"index",
"unique_constraint",
"foreign_key_constraint",
"check_constraint",
],
bool,
Optional[SchemaItem],
Expand Down Expand Up @@ -183,6 +185,7 @@ def configure(
Optional[bool],
],
] = False,
compare_check_constraints: bool = False,
render_item: Optional[
Callable[[str, Any, AutogenContext], Union[str, Literal[False]]]
] = None,
Expand Down Expand Up @@ -391,6 +394,20 @@ def configure(

:paramref:`.EnvironmentContext.configure.compare_type`

:param compare_check_constraints: Indicates check constraint comparison
behavior during an autogenerate operation. Defaults to ``False``
which disables check constraint comparison. Set to ``True`` to
turn on check constraint comparison, which will detect added,
removed, and modified named check constraints.

This feature requires that check constraints have explicit names.
Unnamed check constraints will not be detected.

Check constraint comparison may produce false positives if the
database normalizes the SQL text differently from how it was
originally defined. This is an opt-in feature due to potential
compatibility issues across different database backends.

:param include_name: A callable function which is given
the chance to return ``True`` or ``False`` for any database reflected
object based on its name, including database schema names when
Expand All @@ -404,7 +421,8 @@ def configure(
database connection.
* ``type``: a string describing the type of object; currently
``"schema"``, ``"table"``, ``"column"``, ``"index"``,
``"unique_constraint"``, or ``"foreign_key_constraint"``
``"unique_constraint"``, ``"foreign_key_constraint"``, or
``"check_constraint"``
* ``parent_names``: a dictionary of "parent" object names, that are
relative to the name being given. Keys in this dictionary may
include: ``"schema_name"``, ``"table_name"`` or
Expand Down Expand Up @@ -443,14 +461,15 @@ def configure(
* ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such
as a :class:`~sqlalchemy.schema.Table`,
:class:`~sqlalchemy.schema.Column`,
:class:`~sqlalchemy.schema.Index`
:class:`~sqlalchemy.schema.Index`,
:class:`~sqlalchemy.schema.UniqueConstraint`,
or :class:`~sqlalchemy.schema.ForeignKeyConstraint` object
:class:`~sqlalchemy.schema.ForeignKeyConstraint`, or
:class:`~sqlalchemy.schema.CheckConstraint` object
* ``name``: the name of the object. This is typically available
via ``object.name``.
* ``type``: a string describing the type of object; currently
``"table"``, ``"column"``, ``"index"``, ``"unique_constraint"``,
or ``"foreign_key_constraint"``
``"foreign_key_constraint"``, or ``"check_constraint"``
* ``reflected``: ``True`` if the given object was produced based on
table reflection, ``False`` if it's from a local :class:`.MetaData`
object.
Expand Down
68 changes: 68 additions & 0 deletions alembic/ddl/_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from typing import TypeVar
from typing import Union

import re

from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
Expand Down Expand Up @@ -86,6 +89,7 @@ class _constraint_sig(Generic[_C]):
_is_index: ClassVar[bool] = False
_is_fk: ClassVar[bool] = False
_is_uq: ClassVar[bool] = False
_is_ck: ClassVar[bool] = False

_is_metadata: bool

Expand Down Expand Up @@ -327,3 +331,67 @@ def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:

def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
return sig._is_fk


class _ck_constraint_sig(_constraint_sig[CheckConstraint]):
_is_ck: ClassVar[bool] = True
_VALID_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

@classmethod
def _register(cls) -> None:
_clsreg["check_constraint"] = cls
_clsreg["table_or_column_check_constraint"] = cls
_clsreg["column_check_constraint"] = cls

def __init__(
self,
is_metadata: bool,
impl: DefaultImpl,
const: CheckConstraint,
) -> None:
self.impl = impl
self.const = const
self.name = sqla_compat.constraint_name_or_none(const.name)
self._is_metadata = is_metadata
self._sig = (self._normalize_sqltext(str(const.sqltext)),)

if is_metadata and self.name is not None:
if not self._VALID_NAME_PATTERN.match(self.name):
table_name = (
const.table.name
if const.table is not None
else "<unknown>"
)
raise ValueError(
f"Check constraint name {self.name!r} on table "
f"{table_name!r} contains invalid characters. "
f"Constraint names must contain only alphanumeric "
f"characters and underscores, and must start with "
f"a letter or underscore."
)

@staticmethod
def _normalize_sqltext(sqltext: str) -> str:
normalized = re.sub(r"\s+", " ", sqltext.strip().lower())
normalized = re.sub(r"\(\s+", "(", normalized)
normalized = re.sub(r"\s+\)", ")", normalized)
return normalized

@property
def sqltext(self) -> str:
return str(self.const.sqltext)

def _compare_to_reflected(
self, other: _constraint_sig[CheckConstraint]
) -> ComparisonResult:
assert self._is_metadata
assert is_ck_sig(other)
return self.impl.compare_check_constraint(self.const, other.const)

@util.memoized_property
def unnamed_no_options(self) -> Tuple[Any, ...]:
return self._sig


def is_ck_sig(sig: _constraint_sig) -> TypeGuard[_ck_constraint_sig]:
return sig._is_ck
Loading