From 3d4661dac330592d057f0bfeee50c48a7b5733eb Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Fri, 24 May 2024 17:17:01 +0200 Subject: [PATCH] local task job: add timeout, to not kill on_task_instance_success listener prematurely Signed-off-by: Maciej Obuchowski --- airflow/config_templates/config.yml | 8 ++ airflow/jobs/local_task_job_runner.py | 9 +- .../providers/openlineage/plugins/listener.py | 1 - tests/dags/test_mark_state.py | 17 +++ tests/jobs/test_local_task_job.py | 115 +++++++++++++++++- tests/listeners/slow_listener.py | 26 ++++ tests/listeners/very_slow_listener.py | 26 ++++ 7 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 tests/listeners/slow_listener.py create mode 100644 tests/listeners/very_slow_listener.py diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index b1e2881907ce9..ad9b79483dd89 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -341,6 +341,14 @@ core: type: string example: ~ default: "downstream" + task_success_overtime: + description: | + Maximum possible time (in seconds) that task will have for execution of auxiliary processes + (like listeners, mini scheduler...) after task is marked as success. + version_added: 2.10.0 + type: integer + example: ~ + default: "20" default_task_execution_timeout: description: | The default task execution_timeout value for the operators. 
Expected an integer value to diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py index 16b87f2b818e8..4dd949192e7c8 100644 --- a/airflow/jobs/local_task_job_runner.py +++ b/airflow/jobs/local_task_job_runner.py @@ -110,6 +110,8 @@ def __init__( self.terminating = False self._state_change_checks = 0 + # time spent after task completed, but before it exited - used to measure listener execution time + self._overtime = 0.0 def _execute(self) -> int | None: from airflow.task.task_runner import get_task_runner @@ -195,7 +197,6 @@ def sigusr2_debug_handler(signum, frame): self.job.heartrate if self.job.heartrate is not None else heartbeat_time_limit, ), ) - return_code = self.task_runner.return_code(timeout=max_wait_time) if return_code is not None: self.handle_task_exit(return_code) @@ -290,6 +291,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: ) raise AirflowException("PID of job runner does not match") elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"): + self._overtime = (timezone.utcnow() - (ti.end_date or timezone.utcnow())).total_seconds() if ti.state == TaskInstanceState.SKIPPED: # A DagRun timeout will cause tasks to be externally marked as skipped. 
dagrun = ti.get_dagrun(session=session) @@ -303,6 +305,11 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: if dagrun_timeout and execution_time > dagrun_timeout: self.log.warning("DagRun timed out after %s.", execution_time) + # If process still runs after being marked as success, let it run until configured overtime + if ti.state == TaskInstanceState.SUCCESS and self._overtime < conf.getint( + "core", "task_success_overtime" + ): + return # potential race condition, the _run_raw_task commits `success` or other state # but task_runner does not exit right away due to slow process shutdown or any other reasons # let's do a throttle here, if the above case is true, the handle_task_exit will handle it diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 76b60d61b7119..728159a79524f 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -134,7 +134,6 @@ def on_running(): dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None ) data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None - redacted_event = self.adapter.start_task( run_id=task_uuid, job_name=get_job_name(task), diff --git a/tests/dags/test_mark_state.py b/tests/dags/test_mark_state.py index 331da2d498aca..71e3b0e430049 100644 --- a/tests/dags/test_mark_state.py +++ b/tests/dags/test_mark_state.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import time from datetime import datetime from time import sleep @@ -24,6 +25,7 @@ from airflow.operators.python import PythonOperator from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.timezone import utcnow DEFAULT_DATE = datetime(2016, 1, 1) @@ -41,11 +43,22 @@ def success_callback(context): assert context["dag_run"].dag_id == dag_id +def sleep_execution(): + time.sleep(1) + + +def slow_execution(): + import re + + re.match(r"(a?){30}a{30}", "a" * 30) + + def test_mark_success_no_kill(ti): assert ti.state == State.RUNNING # Simulate marking this successful in the UI with create_session() as session: ti.state = State.SUCCESS + ti.end_date = utcnow() session.merge(ti) session.commit() # The below code will not run as heartbeat will detect change of state @@ -103,3 +116,7 @@ def test_mark_skipped_externally(ti): PythonOperator(task_id="test_mark_skipped_externally", python_callable=test_mark_skipped_externally, dag=dag) PythonOperator(task_id="dummy", python_callable=lambda: True, dag=dag) + +PythonOperator(task_id="slow_execution", python_callable=slow_execution, dag=dag) + +PythonOperator(task_id="sleep_execution", python_callable=sleep_execution, dag=dag) diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 6de4dcd6a47be..1fcacc2a2927f 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -33,11 +33,12 @@ import pytest from airflow import settings -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.job import Job, run_job from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, LocalTaskJobRunner from airflow.jobs.scheduler_job_runner import SchedulerJobRunner +from airflow.listeners.listener import get_listener_manager from airflow.models.dag 
import DAG from airflow.models.dagbag import DagBag from airflow.models.serialized_dag import SerializedDagModel @@ -322,6 +323,7 @@ def test_heartbeat_failed_fast(self): delta = (time2 - time1).total_seconds() assert abs(delta - job.heartrate) < 0.8 + @conf_vars({("core", "task_success_overtime"): "1"}) def test_mark_success_no_kill(self, caplog, get_test_dag, session): """ Test that ensures that mark_success in the UI doesn't cause @@ -553,6 +555,7 @@ def test_failure_callback_called_by_airflow_run_raw_process(self, monkeypatch, t assert m, "pid expected in output." assert os.getpid() != int(m.group(1)) + @conf_vars({("core", "task_success_overtime"): "5"}) def test_mark_success_on_success_callback(self, caplog, get_test_dag): """ Test that ensures that where a task is marked success in the UI @@ -583,6 +586,116 @@ def test_mark_success_on_success_callback(self, caplog, get_test_dag): "State of this instance has been externally set to success. Terminating instance." in caplog.text ) + def test_success_listeners_executed(self, caplog, get_test_dag): + """ + Test that ensures that when listeners are executed, the task is not killed before they finish + or timeout + """ + from tests.listeners import slow_listener + + lm = get_listener_manager() + lm.clear() + lm.add_listener(slow_listener) + + dag = get_test_dag("test_mark_state") + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + ) + task = dag.get_task(task_id="sleep_execution") + + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) + + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + with timeout(30): + run_job(job=job, execute_callable=job_runner._execute) + ti.refresh_from_db() + 
assert ( + "State of this instance has been externally set to success. Terminating instance." + not in caplog.text + ) + lm.clear() + + def test_success_slow_listeners_executed_kill(self, caplog, get_test_dag): + """ + Test that ensures that when there are too slow listeners, the task is killed + """ + from tests.listeners import very_slow_listener + + lm = get_listener_manager() + lm.clear() + lm.add_listener(very_slow_listener) + + dag = get_test_dag("test_mark_state") + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + ) + task = dag.get_task(task_id="sleep_execution") + + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) + + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + with timeout(30): + run_job(job=job, execute_callable=job_runner._execute) + ti.refresh_from_db() + assert ( + "State of this instance has been externally set to success. Terminating instance." in caplog.text + ) + lm.clear() + + def test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, caplog, get_test_dag): + """ + Test that ensures that when there are listeners, but the task is taking a long time anyways, + it's not killed by the overtime mechanism. 
+ """ + from tests.listeners import slow_listener + + lm = get_listener_manager() + lm.clear() + lm.add_listener(slow_listener) + + dag = get_test_dag("test_mark_state") + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + with create_session() as session: + dr = dag.create_dagrun( + state=State.RUNNING, + execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + ) + task = dag.get_task(task_id="slow_execution") + + ti = dr.get_task_instance(task.task_id) + ti.refresh_from_task(task) + + job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + with pytest.raises(AirflowTaskTimeout): + with timeout(30): + run_job(job=job, execute_callable=job_runner._execute) + ti.refresh_from_db() + assert ( + "State of this instance has been externally set to success. Terminating instance." + not in caplog.text + ) + lm.clear() + @pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL]) def test_process_os_signal_calls_on_failure_callback( self, monkeypatch, tmp_path, get_test_dag, signal_type diff --git a/tests/listeners/slow_listener.py b/tests/listeners/slow_listener.py new file mode 100644 index 0000000000000..0575e50beb639 --- /dev/null +++ b/tests/listeners/slow_listener.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import time + +from airflow.listeners import hookimpl + + +@hookimpl +def on_task_instance_success(previous_state, task_instance, session): + time.sleep(5) diff --git a/tests/listeners/very_slow_listener.py b/tests/listeners/very_slow_listener.py new file mode 100644 index 0000000000000..688faded975de --- /dev/null +++ b/tests/listeners/very_slow_listener.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import time + +from airflow.listeners import hookimpl + + +@hookimpl +def on_task_instance_success(previous_state, task_instance, session): + time.sleep(30)