Skip to content

Commit

Permalink
Improve typings
Browse files Browse the repository at this point in the history
Misc changes to improve the typing of alembic:
- Improve typing of the revision parameter in various command functions.
- Properly type the :paramref:`.Operations.create_check_constraint.condition`
  parameter of :meth:`.Operations.create_check_constraint` to accept boolean
  expressions.

Fixes: #930
Fixes: #1266
Change-Id: I9e8249bbd34f9f0b388b79e75b76e75f8347d8ee
  • Loading branch information
CaselIT committed Sep 7, 2023
1 parent d10b9d9 commit 5628a22
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 49 deletions.
7 changes: 4 additions & 3 deletions alembic/autogenerate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..runtime.migration import MigrationContext
from ..script.base import Script
from ..script.base import ScriptDirectory
from ..script.revision import _GetRevArg


def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
Expand Down Expand Up @@ -555,18 +556,18 @@ def _to_script(
)

def run_autogenerate(
self, rev: tuple, migration_context: MigrationContext
self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, True)

def run_no_autogenerate(
self, rev: tuple, migration_context: MigrationContext
self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, False)

def _run_environment(
self,
rev: tuple,
rev: _GetRevArg,
migration_context: MigrationContext,
autogenerate: bool,
) -> None:
Expand Down
13 changes: 6 additions & 7 deletions alembic/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
if TYPE_CHECKING:
from alembic.config import Config
from alembic.script.base import Script
from alembic.script.revision import _RevIdType
from .runtime.environment import ProcessRevisionDirectiveFn


Expand Down Expand Up @@ -124,7 +125,7 @@ def revision(
sql: bool = False,
head: str = "head",
splice: bool = False,
branch_label: Optional[str] = None,
branch_label: Optional[_RevIdType] = None,
version_path: Optional[str] = None,
rev_id: Optional[str] = None,
depends_on: Optional[str] = None,
Expand Down Expand Up @@ -244,9 +245,7 @@ def retrieve_migrations(rev, context):
return scripts


def check(
config: "Config",
) -> None:
def check(config: "Config") -> None:
"""Check if revision command with autogenerate has pending upgrade ops.
:param config: a :class:`.Config` object.
Expand Down Expand Up @@ -302,9 +301,9 @@ def retrieve_migrations(rev, context):

def merge(
config: Config,
revisions: str,
revisions: _RevIdType,
message: Optional[str] = None,
branch_label: Optional[str] = None,
branch_label: Optional[_RevIdType] = None,
rev_id: Optional[str] = None,
) -> Optional[Script]:
"""Merge two revisions together. Creates a new migration file.
Expand Down Expand Up @@ -623,7 +622,7 @@ def display_version(rev, context):

def stamp(
config: Config,
revision: str,
revision: _RevIdType,
sql: bool = False,
tag: Optional[str] = None,
purge: bool = False,
Expand Down
8 changes: 3 additions & 5 deletions alembic/ddl/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
Expand Down Expand Up @@ -513,7 +513,7 @@ def __init__(
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
where: Optional[Union[BinaryExpression, str]] = None,
where: Optional[Union[ColumnElement[bool], str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
Expand All @@ -538,9 +538,7 @@ def from_constraint( # type:ignore[override]
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=cast(
"Optional[Union[BinaryExpression, str]]", constraint.where
),
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
Expand Down
4 changes: 2 additions & 2 deletions alembic/op.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from sqlalchemy.sql.expression import Update

if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
Expand Down Expand Up @@ -481,7 +481,7 @@ def bulk_insert(
def create_check_constraint(
constraint_name: Optional[str],
table_name: str,
condition: Union[str, BinaryExpression, TextClause],
condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
Expand Down
6 changes: 3 additions & 3 deletions alembic/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from sqlalchemy import Table
from sqlalchemy.engine import Connection
from sqlalchemy.sql.expression import BinaryExpression
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.sql.expression import Update
Expand Down Expand Up @@ -861,7 +861,7 @@ def create_check_constraint(
self,
constraint_name: Optional[str],
table_name: str,
condition: Union[str, BinaryExpression, TextClause],
condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def alter_column(
def create_check_constraint(
self,
constraint_name: str,
condition: Union[str, BinaryExpression, TextClause],
condition: Union[str, ColumnElement[bool], TextClause],
**kw: Any,
) -> None:
"""Issue a "create check constraint" instruction using the
Expand Down
5 changes: 2 additions & 3 deletions alembic/operations/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.dml import Update
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import quoted_name
Expand Down Expand Up @@ -788,7 +787,7 @@ def create_check_constraint(
operations: Operations,
constraint_name: Optional[str],
table_name: str,
condition: Union[str, BinaryExpression, TextClause],
condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
Expand Down Expand Up @@ -841,7 +840,7 @@ def batch_create_check_constraint(
cls,
operations: BatchOperations,
constraint_name: str,
condition: Union[str, BinaryExpression, TextClause],
condition: Union[str, ColumnElement[bool], TextClause],
**kw: Any,
) -> None:
"""Issue a "create check constraint" instruction using the
Expand Down
4 changes: 2 additions & 2 deletions alembic/runtime/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,7 @@ def merge_branch_idents(
self.to_revisions[0],
)

def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]:
other_heads = set(heads).difference([self.revision.revision])
if other_heads:
ancestors = {
Expand All @@ -1171,7 +1171,7 @@ def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
return self.to_revisions

def unmerge_branch_idents(
self, heads: Collection[str]
self, heads: Set[str]
) -> Tuple[str, str, Tuple[str, ...]]:
to_revisions = self._unmerge_to_revisions(heads)

Expand Down
7 changes: 3 additions & 4 deletions alembic/script/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..util import not_none

if TYPE_CHECKING:
from .revision import _GetRevArg
from .revision import _RevIdType
from .revision import Revision
from ..config import Config
Expand Down Expand Up @@ -296,7 +297,7 @@ def walk_revisions(
):
yield cast(Script, rev)

def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
Expand Down Expand Up @@ -630,8 +631,7 @@ def generate_revision(
self,
revid: str,
message: Optional[str],
head: Optional[str] = None,
refresh: bool = False,
head: Optional[_RevIdType] = None,
splice: Optional[bool] = False,
branch_labels: Optional[_RevIdType] = None,
version_path: Optional[str] = None,
Expand All @@ -653,7 +653,6 @@ def generate_revision(
:param splice: if True, allow the "head" version to not be an
actual head; otherwise, the selected head must be a head
(e.g. endpoint) revision.
:param refresh: deprecated.
"""
if head is None:
Expand Down
51 changes: 31 additions & 20 deletions alembic/script/revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,25 @@
if TYPE_CHECKING:
from typing import Literal

_RevIdType = Union[str, Sequence[str]]
_RevIdType = Union[str, List[str], Tuple[str, ...]]
_GetRevArg = Union[
str,
List[Optional[str]],
Tuple[Optional[str], ...],
FrozenSet[Optional[str]],
Set[Optional[str]],
List[str],
Tuple[str, ...],
FrozenSet[str],
Set[str],
]
_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
_RevisionOrStr = Union["Revision", str]
_RevisionOrBase = Union["Revision", "Literal['base']"]
_InterimRevisionMapType = Dict[str, "Revision"]
_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
_T = TypeVar("_T", bound=Union[str, "Revision"])
_T = TypeVar("_T")
_TR = TypeVar("_TR", bound=Optional[_RevisionOrStr])

_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]
Expand Down Expand Up @@ -501,7 +513,7 @@ def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]:
return self.filter_for_lineage(self.bases, identifier)

def get_revisions(
self, id_: Union[str, Collection[Optional[str]], None]
self, id_: Optional[_GetRevArg]
) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
Expand All @@ -523,9 +535,7 @@ def get_revisions(
if isinstance(id_, (list, tuple, set, frozenset)):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
resolved_id, branch_label = self._resolve_revision_number(
id_ # type:ignore [arg-type]
)
resolved_id, branch_label = self._resolve_revision_number(id_)
if len(resolved_id) == 1:
try:
rint = int(resolved_id[0])
Expand Down Expand Up @@ -590,7 +600,7 @@ def _resolve_branch(self, branch_label: str) -> Optional[Revision]:

def _revision_for_ident(
self,
resolved_id: Union[str, Tuple[()]],
resolved_id: Union[str, Tuple[()], None],
check_branch: Optional[str] = None,
) -> Optional[Revision]:
branch_rev: Optional[Revision]
Expand Down Expand Up @@ -669,10 +679,10 @@ def _filter_into_branch_heads(

def filter_for_lineage(
self,
targets: Iterable[_T],
targets: Iterable[_TR],
check_against: Optional[str],
include_dependencies: bool = False,
) -> Tuple[_T, ...]:
) -> Tuple[_TR, ...]:
id_, branch_label = self._resolve_revision_number(check_against)

shares = []
Expand All @@ -691,7 +701,7 @@ def filter_for_lineage(

def _shares_lineage(
self,
target: _RevisionOrStr,
target: Optional[_RevisionOrStr],
test_against_revs: Sequence[_RevisionOrStr],
include_dependencies: bool = False,
) -> bool:
Expand Down Expand Up @@ -1211,7 +1221,7 @@ def _parse_upgrade_target(
# No relative destination, target is absolute.
return self.get_revisions(target)

current_revisions_tup: Union[str, Collection[Optional[str]], None]
current_revisions_tup: Union[str, Tuple[Optional[str], ...], None]
current_revisions_tup = util.to_tuple(current_revisions)

branch_label, symbol, relative_str = match.groups()
Expand All @@ -1224,7 +1234,8 @@ def _parse_upgrade_target(
start_revs = current_revisions_tup
if branch_label:
start_revs = self.filter_for_lineage(
self.get_revisions(current_revisions_tup), branch_label
self.get_revisions(current_revisions_tup), # type: ignore[arg-type] # noqa: E501
branch_label,
)
if not start_revs:
# The requested branch is not a head, so we need to
Expand Down Expand Up @@ -1577,8 +1588,8 @@ def __init__(

self.verify_rev_id(revision)
self.revision = revision
self.down_revision = tuple_rev_as_scalar(down_revision)
self.dependencies = tuple_rev_as_scalar(dependencies)
self.down_revision = tuple_rev_as_scalar(util.to_tuple(down_revision))
self.dependencies = tuple_rev_as_scalar(util.to_tuple(dependencies))
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)

Expand Down Expand Up @@ -1676,20 +1687,20 @@ def is_merge_point(self) -> bool:


@overload
def tuple_rev_as_scalar(
rev: Optional[Sequence[str]],
) -> Optional[Union[str, Sequence[str]]]:
def tuple_rev_as_scalar(rev: None) -> None:
...


@overload
def tuple_rev_as_scalar(
rev: Optional[Sequence[Optional[str]]],
) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
rev: Union[Tuple[_T, ...], List[_T]]
) -> Union[_T, Tuple[_T, ...], List[_T]]:
...


def tuple_rev_as_scalar(rev):
def tuple_rev_as_scalar(
rev: Optional[Sequence[_T]],
) -> Union[_T, Sequence[_T], None]:
if not rev:
return None
elif len(rev) == 1:
Expand Down
13 changes: 13 additions & 0 deletions docs/build/unreleased/improve_typing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. change::
:tags: typing
:tickets: 930

Improve typing of the revision parameter in various command functions.

.. change::
:tags: typing, bug
:tickets: 1266

Properly type the :paramref:`.Operations.create_check_constraint.condition`
parameter of :meth:`.Operations.create_check_constraint` to accept boolean
expressions.

0 comments on commit 5628a22

Please sign in to comment.