Skip to content

Commit

Permalink
task runner: notify of component start and finish (#27855)
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski authored Nov 24, 2022
1 parent efaabd9 commit 7d79812
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 4 deletions.
21 changes: 17 additions & 4 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
Expand Down Expand Up @@ -313,6 +314,10 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
root_logger.handlers[:] = orig_handlers


class TaskCommandMarker:
"""Marker for listener hooks, to properly detect from which component they are called."""


@cli_utils.action_cli(check_db=False)
def task_run(args, dag=None):
"""
Expand Down Expand Up @@ -364,6 +369,8 @@ def task_run(args, dag=None):
# processing hundreds of simultaneous tasks.
settings.reconfigure_orm(disable_connection_pool=True)

get_listener_manager().hook.on_starting(component=TaskCommandMarker())

if args.pickle:
print(f"Loading pickle id: {args.pickle}")
dag = get_dag_by_pickle(args.pickle)
Expand All @@ -380,11 +387,17 @@ def task_run(args, dag=None):

log.info("Running %s on host %s", ti, hostname)

if args.interactive:
_run_task_by_selected_method(args, dag, ti)
else:
with _capture_task_logs(ti):
try:
if args.interactive:
_run_task_by_selected_method(args, dag, ti)
else:
with _capture_task_logs(ti):
_run_task_by_selected_method(args, dag, ti)
finally:
try:
get_listener_manager().hook.before_stopping(component=TaskCommandMarker())
except Exception:
pass


@cli_utils.action_cli(check_db=False)
Expand Down
44 changes: 44 additions & 0 deletions tests/listeners/file_write_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging

from airflow.cli.commands.task_command import TaskCommandMarker
from airflow.listeners import hookimpl

log = logging.getLogger(__name__)


class FileWriteListener:
def __init__(self, path):
self.path = path

def write(self, line: str):
with open(self.path, "a") as f:
f.write(line + "\n")

@hookimpl
def on_starting(self, component):
if isinstance(component, TaskCommandMarker):
self.write("on_starting")

@hookimpl
def before_stopping(self, component):
if isinstance(component, TaskCommandMarker):
self.write("before_stopping")
49 changes: 49 additions & 0 deletions tests/task/task_runner/test_standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
Expand All @@ -37,6 +38,7 @@
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from tests.listeners.file_write_listener import FileWriteListener
from tests.test_utils.db import clear_db_runs

TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]
Expand Down Expand Up @@ -111,6 +113,53 @@ def test_start_and_terminate(self):

assert runner.return_code() is not None

def test_notifies_about_start_and_stop(self):
path_listener_writer = "/tmp/path_listener_writer"
try:
os.unlink(path_listener_writer)
except OSError:
pass

lm = get_listener_manager()
lm.add_listener(FileWriteListener(path_listener_writer))

dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
)
dag = dagbag.dags.get("test_example_bash_operator")
task = dag.get_task("runme_1")

with create_session() as session:
dag.create_dagrun(
run_id="test",
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
state=State.RUNNING,
start_date=DEFAULT_DATE,
session=session,
)
ti = TaskInstance(task=task, run_id="test")
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
session.commit()
ti.refresh_from_task(task)

runner = StandardTaskRunner(job1)
runner.start()

# Wait until process sets its pgid to be equal to pid
with timeout(seconds=1):
while True:
runner_pgid = os.getpgid(runner.process.pid)
if runner_pgid == runner.process.pid:
break
time.sleep(0.01)

# Wait till process finishes
assert runner.return_code(timeout=10) is not None
with open(path_listener_writer) as f:
assert f.readline() == "on_starting\n"
assert f.readline() == "before_stopping\n"

def test_start_and_terminate_run_as_user(self):
local_task_job = mock.Mock()
local_task_job.task_instance = mock.MagicMock()
Expand Down

0 comments on commit 7d79812

Please sign in to comment.