Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,6 @@ def post_clear_task_instances(
clear_task_instances(
task_instances,
session,
dag,
DagRunState.QUEUED if reset_dag_runs else False,
)

Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def clear(
# definition code
assert isinstance(self.dag, SchedulerDAG)

clear_task_instances(results, session, dag=self.dag)
clear_task_instances(results, session)
session.commit()
return count

Expand Down
3 changes: 1 addition & 2 deletions airflow-core/src/airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset
from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.security import permissions
from airflow.settings import json
from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
Expand Down Expand Up @@ -468,6 +467,7 @@ def _upgrade_outdated_dag_access_control(access_control=None):
return None

from airflow.providers.fab import __version__ as FAB_VERSION
from airflow.providers.fab.www.security import permissions

updated_access_control = {}
for role, perms in access_control.items():
Expand Down Expand Up @@ -1526,7 +1526,6 @@ def clear(
clear_task_instances(
list(tis),
session,
dag=self,
dag_run_state=dag_run_state,
)
else:
Expand Down
22 changes: 16 additions & 6 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import AssetActive, AssetEvent, AssetModel
from airflow.models.base import Base, StringID, TaskInstanceDependencies
from airflow.models.dagbag import DagBag
from airflow.models.log import Log
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -255,7 +254,6 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None
def clear_task_instances(
tis: list[TaskInstance],
session: Session,
dag: DAG | None = None,
dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
) -> None:
"""
Expand All @@ -271,11 +269,13 @@ def clear_task_instances(
:param session: current session
:param dag_run_state: state to set finished DagRuns to.
If set to False, DagRuns state will not be changed.
:param dag: DAG object

:meta private:
"""
# taskinstance uuids:
task_instance_ids: list[str] = []
dag_bag = DagBag(read_dags_from_db=True)
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

scheduler_dagbag = SchedulerDagBag()

for ti in tis:
task_instance_ids.append(ti.id)
Expand All @@ -285,7 +285,10 @@ def clear_task_instances(
# the task is terminated and becomes eligible for retry.
ti.state = TaskInstanceState.RESTARTING
else:
ti_dag = dag if dag and dag.dag_id == ti.dag_id else dag_bag.get_dag(ti.dag_id, session=session)
dr = ti.dag_run
ti_dag = scheduler_dagbag.get_dag(dag_run=dr, session=session)
if not ti_dag:
log.warning("No serialized dag found for dag '%s'", dr.dag_id)
task_id = ti.task_id
if ti_dag and ti_dag.has_task(task_id):
task = ti_dag.get_task(task_id)
Expand Down Expand Up @@ -326,6 +329,13 @@ def clear_task_instances(
if dr.state in State.finished_dr_states:
dr.state = dag_run_state
dr.start_date = timezone.utcnow()
dr_dag = scheduler_dagbag.get_dag(dag_run=dr, session=session)
if not dr_dag:
log.warning("No serialized dag found for dag '%s'", dr.dag_id)
if dr_dag and not dr_dag.disable_bundle_versioning:
bundle_version = dr.dag_model.bundle_version
if bundle_version is not None:
dr.bundle_version = bundle_version
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from tests_common.test_utils.db import clear_db_runs

pytestmark = pytest.mark.db_test
pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]


class TestTaskInstancesLog:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2271,8 +2271,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, t

# dag (3rd argument) is a different session object. Manually asserting that the dag_id
# is the same.
mock_clearti.assert_called_once_with([], mock.ANY, mock.ANY, DagRunState.QUEUED)
assert mock_clearti.call_args[0][2].dag_id == dag_id
mock_clearti.assert_called_once_with([], mock.ANY, DagRunState.QUEUED)

def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, session):
"""Test that dagrun is running when invalid task_ids are passed to clearTaskInstances API."""
Expand Down
57 changes: 56 additions & 1 deletion airflow-core/tests/unit/models/test_backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pytest
from sqlalchemy import select

from airflow.models import DagRun, TaskInstance
from airflow.models import DagModel, DagRun, TaskInstance
from airflow.models.backfill import (
AlreadyRunningBackfill,
Backfill,
Expand Down Expand Up @@ -152,6 +152,61 @@ def test_create_backfill_simple(reverse, existing, dag_maker, session):
assert all(x.conf == expected_run_conf for x in dag_runs)


def test_create_backfill_clear_existing_bundle_version(dag_maker, session):
"""
Verify that when backfill clears an existing dag run, bundle version is cleared.
"""
# two that will be reprocessed, and an old one not to be processed by backfill
existing = ["1985-01-01", "2021-01-02", "2021-01-03"]
run_ids = {d: f"scheduled_{d}" for d in existing}
with dag_maker(schedule="@daily") as dag:
PythonOperator(task_id="hi", python_callable=print)

dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == dag.dag_id))
first_bundle_version = "bundle_VclmpcTdXv"
dag_model.bundle_version = first_bundle_version
session.commit()
for date in existing:
dag_maker.create_dagrun(
run_id=run_ids[date], logical_date=timezone.parse(date), session=session, state="failed"
)
session.commit()

# update bundle version
new_bundle_version = "bundle_VclmpcTdXv-2"
dag_model.bundle_version = new_bundle_version
session.commit()

# verify that existing dag runs still have the first bundle version
dag_runs = list(session.scalars(select(DagRun).where(DagRun.dag_id == dag.dag_id)))
assert [x.bundle_version for x in dag_runs] == 3 * [first_bundle_version]
assert [x.state for x in dag_runs] == 3 * ["failed"]
session.commit()
_create_backfill(
dag_id=dag.dag_id,
from_date=pendulum.parse("2021-01-01"),
to_date=pendulum.parse("2021-01-05"),
max_active_runs=10,
reverse=False,
dag_run_conf=None,
reprocess_behavior=ReprocessBehavior.FAILED,
)
session.commit()

# verify that the old dag run (not included in backfill) still has first bundle version
# but the latter 5, which are included in the backfill, have the latest bundle version
dag_runs = sorted(
session.scalars(
select(DagRun).where(
DagRun.dag_id == dag.dag_id,
),
),
key=lambda x: x.logical_date,
)
expected = [first_bundle_version] + 5 * [new_bundle_version]
assert [x.bundle_version for x in dag_runs] == expected


@pytest.mark.parametrize(
"reprocess_behavior, num_in_b, exc_reasons",
[
Expand Down
Loading