Skip to content

Commit

Permalink
Refactor dangling row check to use SQLA queries (apache#19808)
Browse files Browse the repository at this point in the history
This is a prepaoratory refactor to have the move dangling rows
pre-upgrade check make better use of the SQLA Queries -- this is needed
because in a future PR we will add a check for dangling XCom rows, and
that will need to conditionally join against DagRun to get
execution_date (depending on if it is run pre- or post-2.2).

This has been tested with Postgres 9.6, SQLite, MSSQL 2017 and MySQL 5.7

codespell didn't like `froms` as it thinks it is a typo of forms, and
most other cases it would be, except here. Codespell doesn't currently
have a method of ignoring a _single_ line without ignoring the word
everywhere (which we don't want to do) so I have to ignore the exact
_line_. Sad panda
  • Loading branch information
ashb authored Jan 24, 2022
1 parent 31db10b commit cecd4c8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 75 deletions.
2 changes: 2 additions & 0 deletions .codespellignorelines
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f"DELETE {source_table} FROM { ', '.join(_from_name(tbl) for tbl in stmt.froms) }"
for frm in source_query.selectable.froms:
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ repos:
args:
- --ignore-words=docs/spelling_wordlist.txt
- --skip=docs/*/commits.rst,airflow/providers/*/*.rst,*.lock,INTHEWILD.md,*.min.js,docs/apache-airflow/pipeline_example.csv
- --exclude-file=.codespellignorelines
- repo: local
hooks:
- id: autoflake
Expand Down
1 change: 1 addition & 0 deletions .rat-excludes
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
.coverage
.coveragerc
.codecov.yml
.codespellignorelines
.eslintrc
.eslintignore
.flake8
Expand Down
148 changes: 73 additions & 75 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sys
import time
from tempfile import gettempdir
from typing import Any, Callable, Iterable, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Tuple

from sqlalchemy import Table, exc, func, inspect, or_, text
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -62,6 +62,10 @@
from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401
from airflow.version import version

if TYPE_CHECKING:
from sqlalchemy.orm import Query


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -799,69 +803,7 @@ def _format_dangling_error(source_table, target_table, invalid_count, reason):
)


def _move_dangling_run_data_to_new_table(session: Session, source_table: "Table", target_table_name: str):
where_clause = "where dag_id is null or run_id is null or execution_date is null"
_move_dangling_table(session, source_table, target_table_name, where_clause)


def _move_dangling_table(session, source_table: "Table", target_table_name: str, where_clause: str):
dialect_name = session.get_bind().dialect.name

delete_where = " AND ".join(
f"{source_table.name}.{c.name} = d.{c.name}" for c in source_table.primary_key.columns
)
if dialect_name == "mssql":
session.execute(
text(f"select source.* into {target_table_name} from {source_table} as source {where_clause}")
)
session.execute(
text(
f"delete from {source_table} from {source_table} join {target_table_name} AS d ON "
+ delete_where
)
)
else:
if dialect_name == "mysql":
# CREATE TABLE AS SELECT must be broken into two queries for MySQL as the single query
# approach fails when replication is enabled ("Statement violates GTID consistency")```
session.execute(text(f"create table {target_table_name} like {source_table}"))
session.execute(
text(
f"INSERT INTO {target_table_name} select source.* from {source_table} as source "
+ where_clause
)
)
# Postgres and SQLite have the same CREATE TABLE a AS SELECT ... syntax
else:
session.execute(
text(
f"create table {target_table_name} as select source.* from {source_table} as source "
+ where_clause
)
)

# But different join-delete syntax.
if dialect_name == "mysql":
session.execute(
text(
f"delete {source_table} from {source_table} join {target_table_name} as d on "
+ delete_where
)
)
elif dialect_name == "sqlite":
session.execute(
text(
f"delete from {source_table} where ROWID in (select {source_table}.ROWID from "
f"{source_table} as source join {target_table_name} as d on {delete_where})"
)
)
else:
session.execute(
text(f"delete from {source_table} using {target_table_name} as d where {delete_where}")
)


def check_run_id_null(session: Session) -> Iterable[str]:
def check_run_id_null(session) -> Iterable[str]:
import sqlalchemy.schema

metadata = sqlalchemy.schema.MetaData(session.bind)
Expand Down Expand Up @@ -891,16 +833,67 @@ def check_run_id_null(session: Session) -> Iterable[str]:
reason="with a NULL dag_id, run_id, or execution_date",
)
return
_move_dangling_run_data_to_new_table(session, dagrun_table, dagrun_dangling_table_name)
_move_dangling_data_to_new_table(
session,
dagrun_table,
dagrun_table.select(invalid_dagrun_filter),
dagrun_dangling_table_name,
)


def _move_dangling_task_data_to_new_table(session, source_table: "Table", target_table_name: str):
where_clause = """
left join dag_run as dr
on (source.dag_id = dr.dag_id and source.execution_date = dr.execution_date)
where dr.id is null
"""
_move_dangling_table(session, source_table, target_table_name, where_clause)
def _move_dangling_data_to_new_table(
session, source_table: "Table", source_query: "Query", target_table_name: str
):
from sqlalchemy import column, select, table
from sqlalchemy.sql.selectable import Join

bind = session.get_bind()
dialect_name = bind.dialect.name

# First: Create moved rows from new table
if dialect_name == "mssql":
cte = source_query.cte("source")
moved_data_tbl = table(target_table_name, *(column(c.name) for c in cte.columns))
ins = moved_data_tbl.insert().from_select(list(cte.columns), select([cte]))

stmt = ins.compile(bind=session.get_bind())
cte_sql = stmt.ctes[cte]

session.execute(f"WITH {cte_sql} SELECT source.* INTO {target_table_name} FROM source")
else:
# Postgres, MySQL and SQLite all support the same "create as select"
session.execute(
f"CREATE TABLE {target_table_name} AS {source_query.selectable.compile(bind=session.get_bind())}"
)

# Second: Now delete rows we've moved
try:
clause = source_query.whereclause
except AttributeError:
clause = source_query._whereclause

if dialect_name == "sqlite":
subq = source_query.selectable.with_only_columns([text(f'{source_table}.ROWID')])
delete = source_table.delete().where(column('ROWID').in_(subq))
elif dialect_name in ("mysql", "mssql"):
# This is not foolproof! But it works for the limited queries (with no params) that we use here
stmt = source_query.selectable

def _from_name(from_) -> str:
if isinstance(from_, Join):
return str(from_.compile(bind=bind))
return str(from_)

delete = (
f"DELETE {source_table} FROM { ', '.join(_from_name(tbl) for tbl in stmt.froms) }"
f" WHERE {clause.compile(bind=bind)}"
)
else:
for frm in source_query.selectable.froms:
if hasattr(frm, 'onclause'): # Table, or JOIN?
clause &= frm.onclause
delete = source_table.delete(clause)
session.execute(delete)


def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str]:
Expand Down Expand Up @@ -945,12 +938,12 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
source_table.c.dag_id == dagrun_table.c.dag_id,
source_table.c.execution_date == dagrun_table.c.execution_date,
)
invalid_row_count = (
invalid_rows_query = (
session.query(source_table.c.dag_id, source_table.c.task_id, source_table.c.execution_date)
.select_from(outerjoin(source_table, dagrun_table, source_to_dag_run_join_cond))
.filter(dagrun_table.c.dag_id.is_(None))
.count()
)
invalid_row_count = invalid_rows_query.count()
if invalid_row_count <= 0:
continue

Expand All @@ -964,7 +957,12 @@ def check_task_tables_without_matching_dagruns(session: Session) -> Iterable[str
)
errored = True
continue
_move_dangling_task_data_to_new_table(session, source_table, dangling_table_name)
_move_dangling_data_to_new_table(
session,
source_table,
invalid_rows_query.with_entities(*source_table.columns),
dangling_table_name,
)

if errored:
session.rollback()
Expand Down

0 comments on commit cecd4c8

Please sign in to comment.