Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,7 @@ def ti_run(
session.query(
func.count(TaskReschedule.id) # or any other primary key column
)
.filter(
TaskReschedule.dag_id == ti.dag_id,
TaskReschedule.task_id == ti_id_str,
TaskReschedule.run_id == ti.run_id,
# TaskReschedule.map_index == ti.map_index, # TODO: Handle mapped tasks
TaskReschedule.try_number == ti.try_number,
)
.filter(TaskReschedule.ti_id == ti_id_str, TaskReschedule.try_number == ti.try_number)
.scalar()
or 0
)
Expand Down Expand Up @@ -360,14 +354,11 @@ def ti_update_state(
actual_start_date = timezone.utcnow()
session.add(
TaskReschedule(
task_instance.task_id,
task_instance.dag_id,
task_instance.run_id,
task_instance.id,
task_instance.try_number,
actual_start_date,
ti_patch_payload.end_date,
ti_patch_payload.reschedule_date,
task_instance.map_index,
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Use ti_id as FK to TaskReschedule.

Revision ID: d469d27e2a64
Revises: 16f7f5ee874e
Create Date: 2025-03-06 16:04:49.106274

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = "d469d27e2a64"
down_revision = "16f7f5ee874e"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply Use ti_id as FK to TaskReschedule."""
dialect_name = op.get_context().dialect.name
with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"ti_id", sa.String(length=36).with_variant(postgresql.UUID(), "postgresql"), nullable=True
)
)
if dialect_name == "postgresql":
op.execute("""
UPDATE task_reschedule SET ti_id = task_instance.id
FROM task_instance
WHERE task_reschedule.task_id = task_instance.task_id
AND task_reschedule.dag_id = task_instance.dag_id
AND task_reschedule.run_id = task_instance.run_id
AND task_reschedule.map_index = task_instance.map_index
""")
elif dialect_name == "mysql":
op.execute("""
UPDATE task_reschedule tir
JOIN task_instance ti ON
tir.task_id = ti.task_id
AND tir.dag_id = ti.dag_id
AND tir.run_id = ti.run_id
AND tir.map_index = ti.map_index
SET tir.ti_id = ti.id
""")
else:
op.execute("""
UPDATE task_reschedule
SET ti_id = (SELECT id FROM task_instance WHERE task_reschedule.task_id = task_instance.task_id
AND task_reschedule.dag_id = task_instance.dag_id
AND task_reschedule.run_id = task_instance.run_id
AND task_reschedule.map_index = task_instance.map_index)
""")
with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
batch_op.alter_column(
"ti_id",
nullable=False,
existing_type=sa.String(length=36).with_variant(postgresql.UUID(), "postgresql"),
)
batch_op.drop_constraint("task_reschedule_ti_fkey", type_="foreignkey")
batch_op.drop_constraint("task_reschedule_dr_fkey", type_="foreignkey")
batch_op.create_foreign_key(
"task_reschedule_ti_fkey", "task_instance", ["ti_id"], ["id"], ondelete="CASCADE"
)

batch_op.drop_index("idx_task_reschedule_dag_run")
batch_op.drop_index("idx_task_reschedule_dag_task_run")
batch_op.drop_column("map_index")
batch_op.drop_column("dag_id")
batch_op.drop_column("task_id")
batch_op.drop_column("run_id")


def downgrade():
"""Unapply Use ti_id as FK to TaskReschedule."""
dialect_name = op.get_context().dialect.name
with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
batch_op.drop_constraint("task_reschedule_ti_fkey", type_="foreignkey")
batch_op.add_column(sa.Column("run_id", StringID(), nullable=True))
batch_op.add_column(sa.Column("task_id", StringID(), nullable=True))
batch_op.add_column(sa.Column("dag_id", StringID(), nullable=True))
batch_op.add_column(
sa.Column(
"map_index",
sa.INTEGER(),
server_default=sa.text("-1"),
nullable=False,
)
)
# fill the task_id, dag_id, run_id, map_index columns from taskinstance
if dialect_name == "postgresql":
op.execute("""
UPDATE task_reschedule
SET dag_id = task_instance.dag_id,
task_id = task_instance.task_id,
run_id = task_instance.run_id,
map_index = task_instance.map_index
FROM task_instance
WHERE task_reschedule.ti_id = task_instance.id
""")
elif dialect_name == "mysql":
op.execute("""
UPDATE task_reschedule tir
JOIN task_instance ti ON
tir.ti_id = ti.id
SET tir.dag_id = ti.dag_id,
tir.task_id = ti.task_id,
tir.run_id = ti.run_id,
tir.map_index = ti.map_index
""")
else:
op.execute("""
UPDATE task_reschedule
SET dag_id = (SELECT dag_id FROM task_instance WHERE task_reschedule.ti_id = task_instance.id),
task_id = (SELECT task_id FROM task_instance WHERE task_reschedule.ti_id = task_instance.id),
run_id = (SELECT run_id FROM task_instance WHERE task_reschedule.ti_id = task_instance.id),
map_index = (SELECT map_index FROM task_instance WHERE task_reschedule.ti_id = task_instance.id)
""")
with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
batch_op.drop_column("ti_id")
batch_op.alter_column("run_id", nullable=False, existing_type=StringID())
batch_op.alter_column("task_id", nullable=False, existing_type=StringID())
batch_op.alter_column("dag_id", nullable=False, existing_type=StringID())
batch_op.alter_column("map_index", nullable=False, existing_type=sa.INTEGER())

batch_op.create_foreign_key(
"task_reschedule_dr_fkey",
"dag_run",
["dag_id", "run_id"],
["dag_id", "run_id"],
ondelete="CASCADE",
)
batch_op.create_foreign_key(
"task_reschedule_ti_fkey",
"task_instance",
["dag_id", "task_id", "run_id", "map_index"],
["dag_id", "task_id", "run_id", "map_index"],
ondelete="CASCADE",
)
batch_op.create_index(
"idx_task_reschedule_dag_task_run", ["dag_id", "task_id", "run_id", "map_index"], unique=False
)
batch_op.create_index("idx_task_reschedule_dag_run", ["dag_id", "run_id"], unique=False)
49 changes: 7 additions & 42 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,13 @@ def clear_task_instances(
If set to False, DagRuns state will not be changed.
:param dag: DAG object
"""
# Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
)
# taskinstance uuids:
task_instance_ids: list[str] = []
dag_bag = DagBag(read_dags_from_db=True)
from airflow.models.taskinstancehistory import TaskInstanceHistory

for ti in tis:
task_instance_ids.append(ti.id)
TaskInstanceHistory.record_ti(ti, session)
ti.try_id = uuid7()
if ti.state == TaskInstanceState.RUNNING:
Expand All @@ -476,40 +475,10 @@ def clear_task_instances(
ti.external_executor_id = None
ti.clear_next_method_args()
session.merge(ti)
task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)

if task_id_by_key:
if task_instance_ids:
# Clear all reschedules related to the ti to clear

# This is an optimization for the common case where all tis are for a small number
# of dag_id, run_id, try_number, and map_index. Use a nested dict of dag_id,
# run_id, try_number, map_index, and task_id to construct the where clause in a
# hierarchical manner. This speeds up the delete statement by more than 40x for
# large number of tis (50k+).
conditions = or_(
and_(
TR.dag_id == dag_id,
or_(
and_(
TR.run_id == run_id,
or_(
and_(
TR.map_index == map_index,
or_(
and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
for try_number, task_ids in task_tries.items()
),
)
for map_index, task_tries in map_indexes.items()
),
)
for run_id, map_indexes in run_ids.items()
),
)
for dag_id, run_ids in task_id_by_key.items()
)

delete_qry = TR.__table__.delete().where(conditions)
delete_qry = TR.__table__.delete().where(TR.ti_id.in_(task_instance_ids))
session.execute(delete_qry)

if dag_run_state is not False and tis:
Expand Down Expand Up @@ -1600,14 +1569,11 @@ def _handle_reschedule(
# see https://github.com/apache/airflow/pull/21362 for more info
session.add(
TaskReschedule(
ti.task_id,
ti.dag_id,
ti.run_id,
ti.id,
ti.try_number,
actual_start_date,
ti.end_date,
reschedule_exception.reschedule_date,
ti.map_index,
)
)
session.commit()
Expand Down Expand Up @@ -3630,12 +3596,11 @@ def clear_db_references(self, session: Session):
from airflow.models.renderedtifields import RenderedTaskInstanceFields

tables: list[type[TaskInstanceDependencies]] = [
TaskReschedule,
XCom,
RenderedTaskInstanceFields,
TaskMap,
]
tables_by_id: list[type[Base]] = [TaskInstanceNote]
tables_by_id: list[type[Base]] = [TaskInstanceNote, TaskReschedule]
for table in tables:
session.execute(
delete(table).where(
Expand Down
Loading