1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -425,6 +425,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_dagrun.py$|
^airflow-core/tests/unit/utils/test_db_cleanup.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/openlineage/.*\.py$|
126 changes: 64 additions & 62 deletions airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,7 @@

import pendulum
import pytest
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from sqlalchemy.orm import joinedload

from airflow import settings
@@ -155,10 +155,10 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
EmptyOperator(task_id="backfill_task_0")
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)

qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
clear_task_instances(qry, session)
session.flush()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
assert dr0.state == state
assert dr0.clear_number < 1

@@ -170,10 +170,10 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
EmptyOperator(task_id="backfill_task_0")
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)

qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
clear_task_instances(qry, session)
session.flush()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
assert dr0.state == DagRunState.QUEUED
assert dr0.clear_number == 1

@@ -721,22 +721,22 @@ def test_dagrun_set_state_end_date(self, dag_maker, session):
session.merge(dr)
session.commit()

dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date

dr.set_state(DagRunState.RUNNING)
session.merge(dr)
session.commit()

dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))

assert dr_database.end_date is None

dr.set_state(DagRunState.FAILED)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))

assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -764,15 +764,15 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):

dr.update_state()

dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date

ti_op1.set_state(state=TaskInstanceState.RUNNING, session=session)
ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session)
dr.update_state()

dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))

assert dr._state == DagRunState.RUNNING
assert dr.end_date is None
@@ -782,7 +782,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
ti_op2.set_state(state=TaskInstanceState.FAILED, session=session)
dr.update_state()

dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))

assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -1216,7 +1216,7 @@ def test_dag_run_dag_versions_method(self, dag_maker, session):
EmptyOperator(task_id="empty")
dag_run = dag_maker.create_dagrun()

dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one()
dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions)))
assert dag_run.dag_versions[0].id == dm.dag_versions[0].id

def test_dag_run_version_number(self, dag_maker, session):
@@ -1231,7 +1231,7 @@ def test_dag_run_version_number(self, dag_maker, session):
tis[1].dag_version = dag_v
session.merge(tis[1])
session.flush()
dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one()
dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id))
# Check that dag_run.version_number returns the version number of
# the latest task instance dag_version
assert dag_run.version_number == dag_v.version_number
@@ -1337,14 +1337,14 @@ def test_dagrun_success_deadline_prune(self, dag_maker, session):
dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id)
dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id)

assert session.query(dag_run1_deadline).scalar()
assert session.query(dag_run2_deadline).scalar()
assert session.scalar(select(dag_run1_deadline))
assert session.scalar(select(dag_run2_deadline))

session.add(dag_run1)
dag_run1.update_state()

assert not session.query(dag_run1_deadline).scalar()
assert session.query(dag_run2_deadline).scalar()
assert not session.scalar(select(dag_run1_deadline))
assert session.scalar(select(dag_run2_deadline))
assert dag_run1.state == DagRunState.SUCCESS
assert dag_run2.state == DagRunState.RUNNING

@@ -1399,13 +1399,12 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session):
mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)

dr = dag_maker.create_dagrun()
indices = (
session.query(TI.map_index)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
indices = session.scalars(
select(TI.map_index)
.where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
.all()
)
assert indices == [(0,), (1,), (2,), (3,)]
).all()
assert indices == [0, 1, 2, 3]


@pytest.mark.parametrize("is_noop", [True, False])
@@ -1422,13 +1421,12 @@ def mynameis(arg):
mynameis.expand(arg=literal)

dr = dag_maker.create_dagrun()
indices = (
session.query(TI.map_index)
.filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id)
indices = session.scalars(
select(TI.map_index)
.where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
.all()
)
assert indices == [(0,), (1,), (2,), (3,)]
).all()
assert indices == [0, 1, 2, 3]


def test_mapped_literal_verify_integrity(dag_maker, session):
Expand All @@ -1444,7 +1442,7 @@ def task_2(arg2): ...

query = (
select(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1483,12 +1481,11 @@ def task_2(arg2): ...
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
dr.verify_integrity(dag_version_id=dag_version_id, session=session)

indices = (
session.query(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
indices = session.execute(
select(TI.map_index, TI.state)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
.all()
)
).all()

assert indices == [
(0, TaskInstanceState.REMOVED),
@@ -1511,7 +1508,7 @@ def task_2(arg2): ...

query = (
select(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1552,7 +1549,7 @@ def task_2(arg2): ...
dr = dag_maker.create_dagrun()
query = (
select(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1661,7 +1658,7 @@ def task_2(arg2): ...
dr.task_instance_scheduling_decisions(session=session)
query = (
select(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1751,7 +1748,7 @@ def task_2(arg2): ...

query = (
select(TI.map_index, TI.state)
.filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
.where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1786,17 +1783,17 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):

dr = dag_maker.create_dagrun()
query = (
session.query(TI.map_index, TI.state)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
select(TI.map_index, TI.state)
.where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)

assert query.all() == [(-1, None)]
assert session.execute(query).all() == [(-1, None)]

# Verify_integrity shouldn't change the result now that the TIs exist
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
dr.verify_integrity(dag_version_id=dag_version_id, session=session)
assert query.all() == [(-1, None)]
assert session.execute(query).all() == [(-1, None)]


def test_mapped_task_group_expands_at_create(dag_maker, session):
@@ -1823,11 +1820,11 @@ def tg(x):

dr = dag_maker.create_dagrun()
query = (
session.query(TI.task_id, TI.map_index, TI.state)
.filter_by(dag_id=dr.dag_id, run_id=dr.run_id)
select(TI.task_id, TI.map_index, TI.state)
.where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.task_id, TI.map_index)
)
assert query.all() == [
assert session.execute(query).all() == [
("tg.t1", 0, None),
("tg.t1", 1, None),
# ("tg.t2", 0, None),
@@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session):
# expanded against a zero-length XCom.
assert decision.finished_tis == [ti1, ti2]

indices = (
session.query(TI.map_index, TI.state)
.filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
indices = session.execute(
select(TI.map_index, TI.state)
.where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
.all()
)
).all()

assert indices == [(-1, TaskInstanceState.SKIPPED)]

@@ -2576,8 +2572,14 @@ def printx(x):

dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
ti = dr1.get_task_instances()[0]
filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
ti = session.scalar(
select(TaskInstance).where(
TaskInstance.dag_id == ti.dag_id,
TaskInstance.task_id == ti.task_id,
TaskInstance.run_id == ti.run_id,
TaskInstance.map_index == ti.map_index,
)
)

tr = TaskReschedule(
ti_id=ti.id,
@@ -2598,10 +2600,10 @@ def printx(x):
XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
session.commit()
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
assert session.query(table).count() == 1
assert session.scalar(select(func.count()).select_from(table)) == 1
dr1.task_instance_scheduling_decisions(session)
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
assert session.query(table).count() == 0
assert session.scalar(select(func.count()).select_from(table)) == 0


def test_dagrun_with_note(dag_maker, session):
@@ -2619,14 +2621,14 @@ def the_task():
session.add(dr)
session.commit()

dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id))
assert dr_note.content == "dag run with note"

session.delete(dr)
session.commit()

assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None
assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None


@pytest.mark.parametrize(
@@ -2655,7 +2657,7 @@ def mytask():
session.flush()
dr.update_state()
session.flush()
dr = session.query(DagRun).one()
dr = session.scalar(select(DagRun))
assert dr.state == dag_run_state


@@ -2694,7 +2696,7 @@ def mytask():
session.flush()
dr.update_state()
session.flush()
dr = session.query(DagRun).one()
dr = session.scalar(select(DagRun))
assert dr.state == dag_run_state


@@ -2729,7 +2731,7 @@ def mytask():
session.flush()
dr.update_state()
session.flush()
dr = session.query(DagRun).one()
dr = session.scalar(select(DagRun))
assert dr.state == DagRunState.FAILED


@@ -2765,7 +2767,7 @@ def mytask():
session.flush()
dr.update_state()
session.flush()
dr = session.query(DagRun).one()
dr = session.scalar(select(DagRun))
assert dr.state == DagRunState.FAILED

