Skip to content

Commit

Permalink
Define type for generic classes
Browse files Browse the repository at this point in the history
Fixed typing use of :class:`~sqlalchemy.schema.Column` and other
generic SQLAlchemy classes.

Fixes: #1246
Change-Id: I5ee80395d626894a52e3395c9986213289576355
  • Loading branch information
CaselIT committed May 16, 2023
1 parent 95adff6 commit 6ad07e2
Show file tree
Hide file tree
Showing 18 changed files with 78 additions and 55 deletions.
24 changes: 12 additions & 12 deletions alembic/autogenerate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,8 @@ def _compare_nullable(
schema: Optional[str],
tname: Union[quoted_name, str],
cname: Union[quoted_name, str],
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> None:

metadata_col_nullable = metadata_col.nullable
Expand Down Expand Up @@ -968,8 +968,8 @@ def _setup_autoincrement(
schema: Optional[str],
tname: Union[quoted_name, str],
cname: quoted_name,
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> None:

if metadata_col.table._autoincrement_column is metadata_col:
Expand All @@ -987,8 +987,8 @@ def _compare_type(
schema: Optional[str],
tname: Union[quoted_name, str],
cname: Union[quoted_name, str],
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> None:

conn_type = conn_col.type
Expand Down Expand Up @@ -1060,8 +1060,8 @@ def _compare_computed_default(
schema: Optional[str],
tname: str,
cname: str,
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> None:
rendered_metadata_default = str(
cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
Expand Down Expand Up @@ -1126,8 +1126,8 @@ def _compare_server_default(
schema: Optional[str],
tname: Union[quoted_name, str],
cname: Union[quoted_name, str],
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> Optional[bool]:

metadata_default = metadata_col.server_default
Expand Down Expand Up @@ -1215,8 +1215,8 @@ def _compare_column_comment(
schema: Optional[str],
tname: Union[quoted_name, str],
cname: quoted_name,
conn_col: Column,
metadata_col: Column,
conn_col: Column[Any],
metadata_col: Column[Any],
) -> Optional[Literal[False]]:

assert autogen_context.dialect is not None
Expand Down
8 changes: 6 additions & 2 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,9 @@ def _user_defined_render(
return False


def _render_column(column: Column, autogen_context: AutogenContext) -> str:
def _render_column(
column: Column[Any], autogen_context: AutogenContext
) -> str:
rendered = _user_defined_render("column", column, autogen_context)
if rendered is not False:
return rendered
Expand Down Expand Up @@ -727,7 +729,9 @@ def _should_render_server_default_positionally(server_default: Any) -> bool:


def _render_server_default(
default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]],
default: Optional[
Union[FetchedValue, str, TextClause, ColumnElement[Any]]
],
autogen_context: AutogenContext,
repr_: bool = True,
) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions alembic/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def configure(
Callable[
[
MigrationContext,
Column,
Column,
Column[Any],
Column[Any],
Optional[str],
Optional[FetchedValue],
Optional[str],
Expand Down
6 changes: 3 additions & 3 deletions alembic/ddl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class AddColumn(AlterTable):
def __init__(
self,
name: str,
column: Column,
column: Column[Any],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(name, schema=schema)
Expand All @@ -159,7 +159,7 @@ def __init__(

class DropColumn(AlterTable):
def __init__(
self, name: str, column: Column, schema: Optional[str] = None
self, name: str, column: Column[Any], schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
Expand Down Expand Up @@ -320,7 +320,7 @@ def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)


def add_column(compiler: DDLCompiler, column: Column, **kw) -> str:
def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)

const = " ".join(
Expand Down
8 changes: 4 additions & 4 deletions alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,15 @@ def alter_column(
def add_column(
self,
table_name: str,
column: Column,
column: Column[Any],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))

def drop_column(
self,
table_name: str,
column: Column,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
Expand Down Expand Up @@ -388,7 +388,7 @@ def create_table_comment(self, table: Table) -> None:
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))

def create_column_comment(self, column: ColumnElement) -> None:
def create_column_comment(self, column: ColumnElement[Any]) -> None:
self._exec(schema.SetColumnComment(column))

def drop_index(self, index: Index) -> None:
Expand Down Expand Up @@ -526,7 +526,7 @@ def _column_args_match(
return True

def compare_type(
self, inspector_column: Column, metadata_column: Column
self, inspector_column: Column[Any], metadata_column: Column
) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
Expand Down
10 changes: 6 additions & 4 deletions alembic/ddl/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def bulk_insert( # type:ignore[override]
def drop_column(
self,
table_name: str,
column: Column,
column: Column[Any],
schema: Optional[str] = None,
**kw,
) -> None:
Expand Down Expand Up @@ -273,7 +273,7 @@ class _ExecDropConstraint(Executable, ClauseElement):
def __init__(
self,
tname: str,
colname: Union[Column, str],
colname: Union[Column[Any], str],
type_: str,
schema: Optional[str],
) -> None:
Expand All @@ -287,7 +287,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement):
inherit_cache = False

def __init__(
self, tname: str, colname: Column, schema: Optional[str]
self, tname: str, colname: Column[Any], schema: Optional[str]
) -> None:
self.tname = tname
self.colname = colname
Expand Down Expand Up @@ -347,7 +347,9 @@ def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
)


def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str:
def mssql_add_column(
compiler: MSDDLCompiler, column: Column[Any], **kw
) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)


Expand Down
2 changes: 1 addition & 1 deletion alembic/ddl/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)


def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str:
def add_column(compiler: OracleDDLCompiler, column: Column[Any], **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)


Expand Down
6 changes: 4 additions & 2 deletions alembic/ddl/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
where: Optional[Union[BinaryExpression, str]] = None,
schema: Optional[str] = None,
Expand Down Expand Up @@ -706,7 +706,9 @@ def do_expr_where_opts():


def _render_potential_column(
value: Union[ColumnClause, Column, TextClause, FunctionElement],
value: Union[
ColumnClause[Any], Column[Any], TextClause, FunctionElement[Any]
],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
Expand Down
6 changes: 3 additions & 3 deletions alembic/ddl/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def drop_constraint(self, const: Constraint):

def compare_server_default(
self,
inspector_column: Column,
metadata_column: Column,
inspector_column: Column[Any],
metadata_column: Column[Any],
rendered_metadata_default: Optional[str],
rendered_inspector_default: Optional[str],
) -> bool:
Expand Down Expand Up @@ -173,7 +173,7 @@ def render_ddl_sql_expr(

def cast_for_batch_migrate(
self,
existing: Column,
existing: Column[Any],
existing_transfer: Dict[str, Union[TypeEngine, Cast]],
new_type: TypeEngine,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion alembic/op.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ _T = TypeVar("_T")
### end imports ###

def add_column(
table_name: str, column: Column, *, schema: Optional[str] = None
table_name: str, column: Column[Any], *, schema: Optional[str] = None
) -> None:
"""Issue an "add column" instruction using the current
migration context.
Expand Down
4 changes: 2 additions & 2 deletions alembic/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ class Operations(AbstractOperations):
def add_column(
self,
table_name: str,
column: Column,
column: Column[Any],
*,
schema: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -1574,7 +1574,7 @@ def _noop(self, operation):

def add_column(
self,
column: Column,
column: Column[Any],
*,
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
Expand Down
9 changes: 6 additions & 3 deletions alembic/operations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str:

def _grab_table_elements(self) -> None:
schema = self.table.schema
self.columns: Dict[str, Column] = OrderedDict()
self.columns: Dict[str, Column[Any]] = OrderedDict()
for c in self.table.c:
c_copy = _copy(c, schema=schema)
c_copy.unique = c_copy.index = False
Expand Down Expand Up @@ -607,7 +607,7 @@ def _setup_dependencies_for_add_column(
def add_column(
self,
table_name: str,
column: Column,
column: Column[Any],
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
**kw,
Expand All @@ -621,7 +621,10 @@ def add_column(
self.column_transfers[column.name] = {}

def drop_column(
self, table_name: str, column: Union[ColumnClause, Column], **kw
self,
table_name: str,
column: Union[ColumnClause[Any], Column[Any]],
**kw,
) -> None:
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
Expand Down
14 changes: 7 additions & 7 deletions alembic/operations/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,7 +1994,7 @@ class AddColumnOp(AlterTableOp):
def __init__(
self,
table_name: str,
column: Column,
column: Column[Any],
*,
schema: Optional[str] = None,
**kw: Any,
Expand All @@ -2010,7 +2010,7 @@ def reverse(self) -> DropColumnOp:

def to_diff_tuple(
self,
) -> Tuple[str, Optional[str], str, Column]:
) -> Tuple[str, Optional[str], str, Column[Any]]:
return ("add_column", self.schema, self.table_name, self.column)

def to_column(self) -> Column:
Expand All @@ -2025,7 +2025,7 @@ def from_column_and_tablename(
cls,
schema: Optional[str],
tname: str,
col: Column,
col: Column[Any],
) -> AddColumnOp:
return cls(tname, col, schema=schema)

Expand All @@ -2034,7 +2034,7 @@ def add_column(
cls,
operations: Operations,
table_name: str,
column: Column,
column: Column[Any],
*,
schema: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -2123,7 +2123,7 @@ def add_column(
def batch_add_column(
cls,
operations: BatchOperations,
column: Column,
column: Column[Any],
*,
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
Expand Down Expand Up @@ -2173,7 +2173,7 @@ def __init__(

def to_diff_tuple(
self,
) -> Tuple[str, Optional[str], str, Column]:
) -> Tuple[str, Optional[str], str, Column[Any]]:
return (
"remove_column",
self.schema,
Expand All @@ -2197,7 +2197,7 @@ def from_column_and_tablename(
cls,
schema: Optional[str],
tname: str,
col: Column,
col: Column[Any],
) -> DropColumnOp:
return cls(
tname,
Expand Down
4 changes: 2 additions & 2 deletions alembic/runtime/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
CompareServerDefault = Callable[
[
MigrationContext,
Column,
Column,
"Column[Any]",
"Column[Any]",
Optional[str],
Optional[FetchedValue],
Optional[str],
Expand Down
6 changes: 3 additions & 3 deletions alembic/runtime/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def config(self) -> Optional[Config]:
return None

def _compare_type(
self, inspector_column: Column, metadata_column: Column
self, inspector_column: Column[Any], metadata_column: Column
) -> bool:
if self._user_compare_type is False:
return False
Expand All @@ -728,8 +728,8 @@ def _compare_type(

def _compare_server_default(
self,
inspector_column: Column,
metadata_column: Column,
inspector_column: Column[Any],
metadata_column: Column[Any],
rendered_metadata_default: Optional[str],
rendered_column_default: Optional[str],
) -> bool:
Expand Down
Loading

0 comments on commit 6ad07e2

Please sign in to comment.