Skip to content

Commit

Permalink
local task job: add timeout, to not kill on_task_instance_success lis…
Browse files Browse the repository at this point in the history
…tener prematurely

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski committed Jun 10, 2024
1 parent e9d8222 commit 3d4661d
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 3 deletions.
8 changes: 8 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,14 @@ core:
type: string
example: ~
default: "downstream"
task_success_overtime:
description: |
Maximum possible time (in seconds) that task will have for execution of auxiliary processes
(like listeners, mini scheduler...) after task is marked as success..
version_added: 2.10.0
type: integer
example: ~
default: "20"
default_task_execution_timeout:
description: |
The default task execution_timeout value for the operators. Expected an integer value to
Expand Down
9 changes: 8 additions & 1 deletion airflow/jobs/local_task_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(
self.terminating = False

self._state_change_checks = 0
# time spend after task completed, but before it exited - used to measure listener execution time
self._overtime = 0.0

def _execute(self) -> int | None:
from airflow.task.task_runner import get_task_runner
Expand Down Expand Up @@ -195,7 +197,6 @@ def sigusr2_debug_handler(signum, frame):
self.job.heartrate if self.job.heartrate is not None else heartbeat_time_limit,
),
)

return_code = self.task_runner.return_code(timeout=max_wait_time)
if return_code is not None:
self.handle_task_exit(return_code)
Expand Down Expand Up @@ -290,6 +291,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
)
raise AirflowException("PID of job runner does not match")
elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"):
self._overtime = (timezone.utcnow() - (ti.end_date or timezone.utcnow())).total_seconds()
if ti.state == TaskInstanceState.SKIPPED:
# A DagRun timeout will cause tasks to be externally marked as skipped.
dagrun = ti.get_dagrun(session=session)
Expand All @@ -303,6 +305,11 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
if dagrun_timeout and execution_time > dagrun_timeout:
self.log.warning("DagRun timed out after %s.", execution_time)

# If process still runs after being marked as success, let it run until configured overtime
if ti.state == TaskInstanceState.SUCCESS and self._overtime < conf.getint(
"core", "task_success_overtime"
):
return
# potential race condition, the _run_raw_task commits `success` or other state
# but task_runner does not exit right away due to slow process shutdown or any other reasons
# let's do a throttle here, if the above case is true, the handle_task_exit will handle it
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def on_running():
dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None
)
data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None

redacted_event = self.adapter.start_task(
run_id=task_uuid,
job_name=get_job_name(task),
Expand Down
17 changes: 17 additions & 0 deletions tests/dags/test_mark_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
# under the License.
from __future__ import annotations

import time
from datetime import datetime
from time import sleep

from airflow.models.dag import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import utcnow

DEFAULT_DATE = datetime(2016, 1, 1)

Expand All @@ -41,11 +43,22 @@ def success_callback(context):
assert context["dag_run"].dag_id == dag_id


def sleep_execution():
time.sleep(1)


def slow_execution():
import re

re.match(r"(a?){30}a{30}", "a" * 30)


def test_mark_success_no_kill(ti):
assert ti.state == State.RUNNING
# Simulate marking this successful in the UI
with create_session() as session:
ti.state = State.SUCCESS
ti.end_date = utcnow()
session.merge(ti)
session.commit()
# The below code will not run as heartbeat will detect change of state
Expand Down Expand Up @@ -103,3 +116,7 @@ def test_mark_skipped_externally(ti):
PythonOperator(task_id="test_mark_skipped_externally", python_callable=test_mark_skipped_externally, dag=dag)

PythonOperator(task_id="dummy", python_callable=lambda: True, dag=dag)

PythonOperator(task_id="slow_execution", python_callable=slow_execution, dag=dag)

PythonOperator(task_id="sleep_execution", python_callable=sleep_execution, dag=dag)
115 changes: 114 additions & 1 deletion tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
import pytest

from airflow import settings
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, LocalTaskJobRunner
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag
from airflow.models.serialized_dag import SerializedDagModel
Expand Down Expand Up @@ -322,6 +323,7 @@ def test_heartbeat_failed_fast(self):
delta = (time2 - time1).total_seconds()
assert abs(delta - job.heartrate) < 0.8

@conf_vars({("core", "task_success_overtime"): "1"})
def test_mark_success_no_kill(self, caplog, get_test_dag, session):
"""
Test that ensures that mark_success in the UI doesn't cause
Expand Down Expand Up @@ -553,6 +555,7 @@ def test_failure_callback_called_by_airflow_run_raw_process(self, monkeypatch, t
assert m, "pid expected in output."
assert os.getpid() != int(m.group(1))

@conf_vars({("core", "task_success_overtime"): "5"})
def test_mark_success_on_success_callback(self, caplog, get_test_dag):
"""
Test that ensures that where a task is marked success in the UI
Expand Down Expand Up @@ -583,6 +586,116 @@ def test_mark_success_on_success_callback(self, caplog, get_test_dag):
"State of this instance has been externally set to success. Terminating instance." in caplog.text
)

def test_success_listeners_executed(self, caplog, get_test_dag):
"""
Test that ensures that when listeners are executed, the task is not killed before they finish
or timeout
"""
from tests.listeners import slow_listener

lm = get_listener_manager()
lm.clear()
lm.add_listener(slow_listener)

dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="sleep_execution")

ti = dr.get_task_instance(task.task_id)
ti.refresh_from_task(task)

job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
with timeout(30):
run_job(job=job, execute_callable=job_runner._execute)
ti.refresh_from_db()
assert (
"State of this instance has been externally set to success. Terminating instance."
not in caplog.text
)
lm.clear()

def test_success_slow_listeners_executed_kill(self, caplog, get_test_dag):
"""
Test that ensures that when there are too slow listeners, the task is killed
"""
from tests.listeners import very_slow_listener

lm = get_listener_manager()
lm.clear()
lm.add_listener(very_slow_listener)

dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="sleep_execution")

ti = dr.get_task_instance(task.task_id)
ti.refresh_from_task(task)

job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
with timeout(30):
run_job(job=job, execute_callable=job_runner._execute)
ti.refresh_from_db()
assert (
"State of this instance has been externally set to success. Terminating instance." in caplog.text
)
lm.clear()

def test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, caplog, get_test_dag):
"""
Test that ensures that when there are listeners, but the task is taking a long time anyways,
it's not killed by the overtime mechanism.
"""
from tests.listeners import slow_listener

lm = get_listener_manager()
lm.clear()
lm.add_listener(slow_listener)

dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="slow_execution")

ti = dr.get_task_instance(task.task_id)
ti.refresh_from_task(task)

job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
with pytest.raises(AirflowTaskTimeout):
with timeout(30):
run_job(job=job, execute_callable=job_runner._execute)
ti.refresh_from_db()
assert (
"State of this instance has been externally set to success. Terminating instance."
not in caplog.text
)
lm.clear()

@pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL])
def test_process_os_signal_calls_on_failure_callback(
self, monkeypatch, tmp_path, get_test_dag, signal_type
Expand Down
26 changes: 26 additions & 0 deletions tests/listeners/slow_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time

from airflow.listeners import hookimpl


@hookimpl
def on_task_instance_success(previous_state, task_instance, session):
time.sleep(5)
26 changes: 26 additions & 0 deletions tests/listeners/very_slow_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time

from airflow.listeners import hookimpl


@hookimpl
def on_task_instance_success(previous_state, task_instance, session):
time.sleep(30)

0 comments on commit 3d4661d

Please sign in to comment.