Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_cleartasks.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/openlineage/.*\.py$|
^task_sdk.*\.py$
Expand Down
37 changes: 18 additions & 19 deletions airflow-core/tests/unit/models/test_cleartasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import random

import pytest
from sqlalchemy import select
from sqlalchemy import func, select

from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_clear_task_instances(self, dag_maker):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)

ti0.refresh_from_db(session)
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_clear_task_instances_external_executor_id(self, dag_maker):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)

ti0.refresh_from_db()
Expand Down Expand Up @@ -186,12 +186,12 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
assert session.query(TaskInstanceHistory).count() == 0
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 0
clear_task_instances(qry, session, dag_run_state=state)
session.flush()
# 2 TIs were cleared so 2 history records should be created
assert session.query(TaskInstanceHistory).count() == 2
assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 2

session.refresh(dr)

Expand Down Expand Up @@ -229,7 +229,7 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()

Expand Down Expand Up @@ -282,7 +282,7 @@ def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_m
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()

Expand Down Expand Up @@ -394,7 +394,7 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)

ti0.refresh_from_db(session=session)
Expand Down Expand Up @@ -477,20 +477,19 @@ def test_clear_task_instances_with_task_reschedule(self, dag_maker):
with create_session() as session:

def count_task_reschedule(ti):
return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count()
return session.scalar(
select(func.count()).select_from(TaskReschedule).where(TaskReschedule.ti_id == ti.id)
)

assert count_task_reschedule(ti0) == 1
assert count_task_reschedule(ti1) == 1
# we use order_by(task_id) here because for the test DAG structure of ours
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = (
session.query(TI)
.filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id)
.order_by(TI.task_id)
.all()
)
qry = session.scalars(
select(TI).where(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).order_by(TI.task_id)
).all()
clear_task_instances(qry, session)
assert count_task_reschedule(ti0) == 0
assert count_task_reschedule(ti1) == 1
Expand Down Expand Up @@ -531,7 +530,7 @@ def test_task_instance_history_record(self, state, state_recorded, dag_maker):
ti1.state = state
session = dag_maker.session
session.flush()
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()

Expand Down Expand Up @@ -716,10 +715,10 @@ def test_clear_task_instances_with_run_on_latest_version(self, run_on_latest_ver
new_dag_version = DagVersion.get_latest_version(dag.dag_id)

assert old_dag_version.id != new_dag_version.id
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session, run_on_latest_version=run_on_latest_version)
session.commit()
dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one()
dr = session.scalar(select(DagRun).where(DagRun.dag_id == dag.dag_id))
if run_on_latest_version:
assert dr.created_dag_version_id == new_dag_version.id
assert dr.bundle_version == new_dag_version.bundle_version
Expand Down