61 changes: 20 additions & 41 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -43,7 +43,6 @@
 from airflow.executors import workloads
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import perform_heartbeat
-from airflow.models import DagBag
 from airflow.models.trigger import Trigger
 from airflow.sdk.api.datamodels._generated import HITLDetailResponse
 from airflow.sdk.execution_time.comms import (
@@ -606,45 +605,6 @@ def update_triggers(self, requested_trigger_ids: set[int]):
         trigger set.
         """
         render_log_fname = log_filename_template_renderer()
-        dag_bag = DagBag(collect_dags=False)
-
-        def expand_start_trigger_args(trigger: Trigger) -> Trigger:
-            task = dag_bag.get_dag(trigger.task_instance.dag_id).get_task(trigger.task_instance.task_id)
-            if task.template_fields:
-                trigger.task_instance.refresh_from_task(task)
-                context = trigger.task_instance.get_template_context()
-                task.render_template_fields(context=context)
-                start_trigger_args = task.expand_start_trigger_args(context=context)
-                if start_trigger_args:
-                    trigger.kwargs = start_trigger_args.trigger_kwargs
-            return trigger
-
-        def create_workload(trigger: Trigger) -> workloads.RunTrigger:
-            if trigger.task_instance:
-                log_path = render_log_fname(ti=trigger.task_instance)
-
-                trigger = expand_start_trigger_args(trigger)
-
-                ser_ti = workloads.TaskInstance.model_validate(trigger.task_instance, from_attributes=True)
-                # When producing logs from TIs, include the job id producing the logs to disambiguate it.
-                self.logger_cache[new_id] = TriggerLoggingFactory(
-                    log_path=f"{log_path}.trigger.{self.job.id}.log",
-                    ti=ser_ti,  # type: ignore
-                )
-
-                return workloads.RunTrigger(
-                    classpath=trigger.classpath,
-                    id=new_id,
-                    encrypted_kwargs=trigger.encrypted_kwargs,
-                    ti=ser_ti,
-                    timeout_after=trigger.task_instance.trigger_timeout,
-                )
-            return workloads.RunTrigger(
-                classpath=trigger.classpath,
-                id=new_id,
-                encrypted_kwargs=trigger.encrypted_kwargs,
-                ti=None,
-            )

         known_trigger_ids = (
             self.running_triggers.union(x[0] for x in self.events)
@@ -682,7 +642,26 @@ def create_workload(trigger: Trigger) -> workloads.RunTrigger:
                 )
                 continue

-            workload = create_workload(new_trigger_orm)
+            workload = workloads.RunTrigger(
+                classpath=new_trigger_orm.classpath,
+                id=new_id,
+                encrypted_kwargs=new_trigger_orm.encrypted_kwargs,
+                ti=None,
+            )
+            if new_trigger_orm.task_instance:
+                log_path = render_log_fname(ti=new_trigger_orm.task_instance)
+
+                ser_ti = workloads.TaskInstance.model_validate(
+                    new_trigger_orm.task_instance, from_attributes=True
+                )
+                # When producing logs from TIs, include the job id producing the logs to disambiguate it.
+                self.logger_cache[new_id] = TriggerLoggingFactory(
+                    log_path=f"{log_path}.trigger.{self.job.id}.log",
+                    ti=ser_ti,  # type: ignore
+                )
+
+                workload.ti = ser_ti
+                workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout

             to_create.append(workload)

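For context on the log handling in the hunk above: each trigger's log file name is derived from the task instance's rendered log path plus the id of the triggerer job that handled it, so two triggerer jobs picking up the same task instance write to distinct files. A minimal sketch of the naming scheme, with made-up values:

```python
# Illustrative only -- the real log_path comes from
# log_filename_template_renderer() applied to the task instance.
log_path = "dag_id=example/run_id=manual__2025-01-01/task_id=wait/attempt=1.log"
job_id = 12345  # id of the triggerer Job that runs this trigger

trigger_log_path = f"{log_path}.trigger.{job_id}.log"
print(trigger_log_path)
# dag_id=example/run_id=manual__2025-01-01/task_id=wait/attempt=1.log.trigger.12345.log
```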
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/mappedoperator.py
@@ -59,11 +59,12 @@
     from airflow.models import TaskInstance
     from airflow.models.dag import DAG as SchedulerDAG
     from airflow.models.expandinput import SchedulerExpandInput
-    from airflow.sdk import BaseOperatorLink, Context, StartTriggerArgs
+    from airflow.sdk import BaseOperatorLink, Context
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.param import ParamsDict
     from airflow.task.trigger_rule import TriggerRule
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+    from airflow.triggers.base import StartTriggerArgs

 Operator: TypeAlias = "SerializedBaseOperator | MappedOperator"

@@ -52,7 +52,6 @@
 from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
 from airflow.sdk import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg
 from airflow.sdk.bases.operator import OPERATOR_DEFAULTS  # TODO: Copy this into the scheduler?
-from airflow.sdk.bases.trigger import StartTriggerArgs
 from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
 from airflow.sdk.definitions._internal.node import DAGNode
 from airflow.sdk.definitions.asset import (
@@ -84,7 +83,7 @@
 from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
 from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
-from airflow.triggers.base import BaseTrigger
+from airflow.triggers.base import BaseTrigger, StartTriggerArgs
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.context import (
     ConnectionAccessor,
23 changes: 10 additions & 13 deletions airflow-core/src/airflow/triggers/base.py
@@ -19,6 +19,8 @@
 import abc
 import json
 from collections.abc import AsyncIterator
+from dataclasses import dataclass
+from datetime import timedelta
 from typing import Annotated, Any

 import structlog
@@ -31,25 +33,20 @@
 )

 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.module_loading import import_string
 from airflow.utils.state import TaskInstanceState

 log = structlog.get_logger(logger_name=__name__)


-def __getattr__(name: str):
-    if name == "StartTriggerArgs":
-        import warnings
+@dataclass
+class StartTriggerArgs:
+    """Arguments required for start task execution from triggerer."""

-        warnings.warn(
-            "airflow.triggers.base.StartTriggerArgs is deprecated. "
-            "Use airflow.sdk.bases.trigger.StartTriggerArgs instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return import_string(f"airflow.sdk.bases.trigger.{name}")
-
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+    trigger_cls: str
+    next_method: str
+    trigger_kwargs: dict[str, Any] | None = None
+    next_kwargs: dict[str, Any] | None = None
+    timeout: timedelta | None = None


 class BaseTrigger(abc.ABC, LoggingMixin):
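For reference, `StartTriggerArgs` is what a deferrable operator exposes so a task can be handed straight to the triggerer without first occupying a worker slot. A minimal sketch of an operator using the dataclass defined above; the operator name and values are illustrative, and the trigger classpath assumes the standard provider's `TimeDeltaTrigger`:

```python
from datetime import timedelta

from airflow.sdk import BaseOperator
from airflow.triggers.base import StartTriggerArgs


class WaitBeforeRunning(BaseOperator):
    # Start deferred in the triggerer instead of on a worker.
    start_from_trigger = True
    start_trigger_args = StartTriggerArgs(
        trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
        trigger_kwargs={"delta": timedelta(minutes=5)},
        next_method="execute_complete",
        next_kwargs=None,
        timeout=None,
    )

    def execute_complete(self, context, event=None):
        self.log.info("Trigger fired; resuming with event %s", event)
```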
68 changes: 11 additions & 57 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -42,7 +42,7 @@
     TriggerRunnerSupervisor,
     messages,
 )
-from airflow.models import DagBag, DagModel, DagRun, TaskInstance, Trigger
+from airflow.models import DagModel, DagRun, TaskInstance, Trigger
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG
 from airflow.models.dag_version import DagVersion
@@ -128,15 +128,6 @@ def create_trigger_in_db(session, trigger, operator=None):
     return dag_model, run, trigger_orm, task_instance


-def mock_dag_bag(mock_dag_bag_cls, task_instance: TaskInstance):
-    mock_dag = MagicMock(spec=DAG)
-    mock_dag.get_task.return_value = task_instance.task
-
-    mock_dag_bag = MagicMock(spec=DagBag)
-    mock_dag_bag.get_dag.return_value = mock_dag
-    mock_dag_bag_cls.return_value = mock_dag_bag
-
-
 def test_is_needed(session):
     """Checks the triggerer-is-needed logic"""
     # No triggers, no need
@@ -215,8 +206,7 @@ def builder(job=None):
     return builder


-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-def test_trigger_lifecycle(mock_dag_bag_cls, spy_agency: SpyAgency, session, testing_dag_bundle):
+def test_trigger_lifecycle(spy_agency: SpyAgency, session, testing_dag_bundle):
     """
     Checks that the triggerer will correctly see a new Trigger in the database
     and send it to the trigger runner, and then delete it when it vanishes.
@@ -225,8 +215,6 @@
     # (we want to avoid it firing and deleting itself)
     trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
     dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger)
-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     # Make a TriggererJobRunner and have it retrieve DB tasks
     trigger_runner_supervisor = TriggerRunnerSupervisor.start(job=Job(id=12345), capacity=10)

@@ -409,10 +397,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):


 @pytest.mark.asyncio
-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-async def test_trigger_create_race_condition_38599(
-    mock_dag_bag_cls, session, supervisor_builder, testing_dag_bundle
-):
+async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
     """
     This verifies the resolution of race condition documented in github issue #38599.
     More details in the issue description.
@@ -441,14 +426,10 @@ async def test_trigger_create_race_condition_38599(
     dm = DagModel(dag_id="test-dag", bundle_name=bundle_name)
     session.add(dm)
     SerializedDagModel.write_dag(dag, bundle_name=bundle_name)
-    dag_run = DagRun(
-        dag.dag_id, run_id="abc", run_type="manual", start_date=timezone.utcnow(), run_after=timezone.utcnow()
-    )
+    dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow())
     dag_version = DagVersion.get_latest_version(dag.dag_id)
-    task = PythonOperator(task_id="dummy-task", python_callable=print)
-    task.dag = dag
     ti = TaskInstance(
-        task,
+        PythonOperator(task_id="dummy-task", python_callable=print),
         run_id=dag_run.run_id,
         state=TaskInstanceState.DEFERRED,
         dag_version_id=dag_version.id,
@@ -465,8 +446,6 @@

     session.commit()

-    mock_dag_bag(mock_dag_bag_cls, ti)
-
     supervisor1 = supervisor_builder(job1)
     supervisor2 = supervisor_builder(job2)

@@ -600,8 +579,7 @@ async def test_trigger_failing():
         info["task"].cancel()


-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-def test_failed_trigger(mock_dag_bag_cls, session, dag_maker, supervisor_builder):
+def test_failed_trigger(session, dag_maker, supervisor_builder):
     """
     Checks that the triggerer will correctly fail task instances that depend on
     triggers that can't even be loaded.
@@ -624,8 +602,6 @@ def test_failed_trigger(mock_dag_bag_cls, session, dag_maker, supervisor_builder
     task_instance.trigger_id = trigger_orm.id
     session.commit()

-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     supervisor: TriggerRunnerSupervisor = supervisor_builder()

     supervisor.load_triggers()
@@ -771,8 +747,7 @@ def handle_events(self):

 @pytest.mark.asyncio
 @pytest.mark.execution_timeout(20)
-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-async def test_trigger_can_call_variables_connections_and_xcoms_methods(mock_dag_bag_cls, session, dag_maker):
+async def test_trigger_can_call_variables_connections_and_xcoms_methods(session, dag_maker):
     """Checks that the trigger will successfully call Variables, Connections and XComs methods."""
     # Create the test DAG and task
     with dag_maker(dag_id="trigger_accessing_variable_connection_and_xcom", session=session):
@@ -834,8 +809,6 @@ async def test_trigger_can_call_variables_connections_and_xcoms_methods(mock_dag
     session.add(job)
     session.commit()

-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None)
     supervisor.run()

@@ -906,10 +879,7 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]:

 @pytest.mark.asyncio
 @pytest.mark.execution_timeout(10)
-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(
-    mock_dag_bag_cls, session, dag_maker
-):
+async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(session, dag_maker):
     """Checks that the trigger will successfully fetch the count of trigger DAG runs."""
     # Create the test DAG and task
     with dag_maker(dag_id="trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable", session=session):
@@ -940,8 +910,6 @@ async def test_trigger_can_fetch_trigger_dag_run_count_and_state_in_deferrable(
     session.add(job)
     session.commit()

-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None)
     supervisor.run()

@@ -1002,8 +970,7 @@ async def run(self, **args) -> AsyncIterator[TriggerEvent]:

 @pytest.mark.asyncio
 @pytest.mark.execution_timeout(10)
-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_bag_cls, session, dag_maker):
+async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(session, dag_maker):
     """Checks that the trigger will successfully fetch the count of DAG runs, Task count and task states."""
     # Create the test DAG and task
     with dag_maker(dag_id="parent_dag", session=session):
@@ -1044,8 +1011,6 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_b
     session.add(job)
     session.commit()

-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, logger=None)
     supervisor.run()

@@ -1058,19 +1023,14 @@ async def test_trigger_can_fetch_dag_run_count_ti_count_in_deferrable(mock_dag_b
     }


-@patch("airflow.jobs.triggerer_job_runner.DagBag")
-def test_update_triggers_prevents_duplicate_creation_queue_entries(
-    mock_dag_bag_cls, session, supervisor_builder
-):
+def test_update_triggers_prevents_duplicate_creation_queue_entries(session, supervisor_builder):
     """
     Test that update_triggers prevents adding triggers to the creation queue
     if they are already queued for creation.
     """
     trigger = TimeDeltaTrigger(datetime.timedelta(days=7))
     dag_model, run, trigger_orm, task_instance = create_trigger_in_db(session, trigger)

-    mock_dag_bag(mock_dag_bag_cls, task_instance)
-
     supervisor = supervisor_builder()

     # First call to update_triggers should add the trigger to creating_triggers
@@ -1092,9 +1052,8 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries(
     assert not any(trigger_id == trigger_orm.id for trigger_id, _ in supervisor.failed_triggers)


-@patch("airflow.jobs.triggerer_job_runner.DagBag")
 def test_update_triggers_prevents_duplicate_creation_queue_entries_with_multiple_triggers(
-    mock_dag_bag_cls, session, supervisor_builder, dag_maker
+    session, supervisor_builder, dag_maker
 ):
     """
     Test that update_triggers prevents adding multiple triggers to the creation queue
@@ -1105,8 +1064,6 @@ def test_update_triggers_prevents_duplicate_creation_queue_entries_with_multiple

     dag_model1, run1, trigger_orm1, task_instance1 = create_trigger_in_db(session, trigger1)

-    mock_dag_bag(mock_dag_bag_cls, task_instance1)
-
     with dag_maker("test_dag_2"):
         EmptyOperator(task_id="test_ti_2")

@@ -1115,9 +1072,6 @@
     ti2 = run2.task_instances[0]
     session.add(trigger_orm2)
     session.flush()
-
-    mock_dag_bag(mock_dag_bag_cls, ti2)
-
     ti2.trigger_id = trigger_orm2.id
     session.merge(ti2)
     session.flush()
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/models/test_dagrun.py
@@ -44,11 +44,11 @@
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
 from airflow.sdk import BaseOperator, setup, task, task_group, teardown
-from airflow.sdk.bases.trigger import StartTriggerArgs
 from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.stats import Stats
 from airflow.task.trigger_rule import TriggerRule
+from airflow.triggers.base import StartTriggerArgs
 from airflow.utils.span_status import SpanStatus
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -65,7 +65,6 @@
 from airflow.sdk import AssetAlias, BaseHook, teardown
 from airflow.sdk.bases.decorator import DecoratedOperator
 from airflow.sdk.bases.operator import BaseOperator
-from airflow.sdk.bases.trigger import StartTriggerArgs
 from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
 from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
 from airflow.sdk.definitions.operator_resources import Resources
@@ -83,6 +82,7 @@
 from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
+from airflow.triggers.base import StartTriggerArgs
 from airflow.utils.module_loading import qualname

 from tests_common.test_utils.config import conf_vars