Skip to content

Commit

Permalink
listener: simplify API by replacing SQLAlchemy event-listening by dir…
Browse files Browse the repository at this point in the history
…ect calls (#29289)
  • Loading branch information
mobuchowski authored Feb 3, 2023
1 parent edc2e0b commit 624520d
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 108 deletions.
9 changes: 0 additions & 9 deletions airflow/jobs/local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
from airflow.listeners.events import register_task_instance_state_events
from airflow.listeners.listener import get_listener_manager
from airflow.models.taskinstance import TaskInstance
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
Expand Down Expand Up @@ -107,7 +105,6 @@ def __init__(
super().__init__(*args, **kwargs)

def _execute(self):
self._enable_task_listeners()
self.task_runner = get_task_runner(self)

def signal_handler(signum, frame):
Expand Down Expand Up @@ -286,9 +283,3 @@ def _log_return_code_metric(self, return_code: int):
Stats.incr(
f"local_task_job.task_exit.{self.id}.{self.dag_id}.{self.task_instance.task_id}.{return_code}"
)

@staticmethod
def _enable_task_listeners():
"""Check for registered listeners, then register sqlalchemy hooks for TI state changes."""
if get_listener_manager().has_listeners:
register_task_instance_state_events()
85 changes: 0 additions & 85 deletions airflow/listeners/events.py

This file was deleted.

15 changes: 15 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base, StringID
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
Expand Down Expand Up @@ -1384,6 +1385,12 @@ def _run_raw_task(

self.task = self.task.prepare_for_execution()
context = self.get_template_context(ignore_param_exceptions=False)

# We lose previous state because it's changed in other process in LocalTaskJob.
# We could probably pass it through here though...
get_listener_manager().hook.on_task_instance_running(
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
)
try:
if not mark_success:
self._execute_task_with_callbacks(context, test_mode)
Expand Down Expand Up @@ -1462,6 +1469,10 @@ def _run_raw_task(
session.merge(self).task = self.task
if self.state == TaskInstanceState.SUCCESS:
self._register_dataset_changes(session=session)
get_listener_manager().hook.on_task_instance_success(
previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
)

session.commit()

def _register_dataset_changes(self, *, session: Session) -> None:
Expand Down Expand Up @@ -1792,6 +1803,10 @@ def handle_failure(
if test_mode is None:
test_mode = self.test_mode

get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session
)

if error:
if isinstance(error, BaseException):
tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
Expand Down
35 changes: 35 additions & 0 deletions tests/dags/test_failing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime

from airflow.models import DAG
from airflow.operators.bash import BashOperator

dag = DAG(
dag_id="test_failing_bash_operator",
default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
schedule="0 0 * * *",
dagrun_timeout=datetime.timedelta(minutes=60),
)

task = BashOperator(task_id="failing_task", bash_command="sleep 1 && exit 1", dag=dag)

if __name__ == "__main__":
dag.cli()
4 changes: 4 additions & 0 deletions tests/listeners/class_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ def on_task_instance_success(self, previous_state, task_instance, session):
@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, session):
self.state.append(State.FAILED)


def clear():
pass
4 changes: 4 additions & 0 deletions tests/listeners/empty_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@
@hookimpl
def on_task_instance_running(previous_state, task_instance, session):
pass


def clear():
pass
16 changes: 16 additions & 0 deletions tests/listeners/file_write_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def write(self, line: str):
with open(self.path, "a") as f:
f.write(line + "\n")

@hookimpl
def on_task_instance_running(self, previous_state, task_instance, session):
self.write("on_task_instance_running")

@hookimpl
def on_task_instance_success(self, previous_state, task_instance, session):
self.write("on_task_instance_success")

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, session):
self.write("on_task_instance_failed")

@hookimpl
def on_starting(self, component):
if isinstance(component, TaskCommandMarker):
Expand All @@ -42,3 +54,7 @@ def on_starting(self, component):
def before_stopping(self, component):
if isinstance(component, TaskCommandMarker):
self.write("before_stopping")


def clear():
pass
4 changes: 4 additions & 0 deletions tests/listeners/partial_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@
@hookimpl
def on_task_instance_running(previous_state, task_instance, session):
state.append(State.RUNNING)


def clear():
pass
38 changes: 26 additions & 12 deletions tests/listeners/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from airflow import AirflowException
from airflow.jobs.base_job import BaseJob
from airflow.listeners import events
from airflow.listeners.listener import get_listener_manager
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
Expand All @@ -35,27 +34,28 @@
throwing_listener,
)

LISTENERS = [
class_listener,
full_listener,
lifecycle_listener,
partial_listener,
throwing_listener,
]

DAG_ID = "test_listener_dag"
TASK_ID = "test_listener_task"
EXECUTION_DATE = timezone.utcnow()


@pytest.fixture(scope="module", autouse=True)
def register_events():
events.register_task_instance_state_events()
yield
events.unregister_task_instance_state_events()


@pytest.fixture(autouse=True)
def clean_listener_manager():
lm = get_listener_manager()
lm.clear()
yield
lm = get_listener_manager()
lm.clear()
full_listener.clear()
lifecycle_listener.clear()
for listener in LISTENERS:
listener.clear()


@provide_session
Expand Down Expand Up @@ -127,8 +127,22 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator
with pytest.raises(AirflowException):
ti._run_raw_task()

assert full_listener.state == [State.FAILED]
assert len(full_listener.state) == 1
assert full_listener.state == [State.RUNNING, State.FAILED]
assert len(full_listener.state) == 2


@provide_session
def test_listener_captures_longrunning_taskinstances(create_task_instance_of_operator, session=None):
lm = get_listener_manager()
lm.add_listener(full_listener)

ti = create_task_instance_of_operator(
BashOperator, dag_id=DAG_ID, execution_date=EXECUTION_DATE, task_id=TASK_ID, bash_command="sleep 5"
)
ti._run_raw_task()

assert full_listener.state == [State.RUNNING, State.SUCCESS]
assert len(full_listener.state) == 2


@provide_session
Expand Down
4 changes: 4 additions & 0 deletions tests/listeners/throwing_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@
@hookimpl
def on_task_instance_success(previous_state, task_instance, session):
raise RuntimeError()


def clear():
pass
49 changes: 47 additions & 2 deletions tests/task/task_runner/test_standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_start_and_terminate(self, mock_init):
assert runner.return_code() is not None

def test_notifies_about_start_and_stop(self):
path_listener_writer = "/tmp/path_listener_writer"
path_listener_writer = "/tmp/test_notifies_about_start_and_stop"
try:
os.unlink(path_listener_writer)
except OSError:
Expand All @@ -152,7 +152,7 @@ def test_notifies_about_start_and_stop(self):
runner = StandardTaskRunner(job1)
runner.start()

# Wait until process makes itself the leader of it's own process group
# Wait until process makes itself the leader of its own process group
with timeout(seconds=1):
while True:
runner_pgid = os.getpgid(runner.process.pid)
Expand All @@ -164,6 +164,51 @@ def test_notifies_about_start_and_stop(self):
assert runner.return_code(timeout=10) is not None
with open(path_listener_writer) as f:
assert f.readline() == "on_starting\n"
assert f.readline() == "on_task_instance_running\n"
assert f.readline() == "on_task_instance_success\n"
assert f.readline() == "before_stopping\n"

def test_notifies_about_fail(self):
path_listener_writer = "/tmp/test_notifies_about_fail"
try:
os.unlink(path_listener_writer)
except OSError:
pass

lm = get_listener_manager()
lm.add_listener(FileWriteListener(path_listener_writer))

dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
)
dag = dagbag.dags.get("test_failing_bash_operator")
task = dag.get_task("failing_task")
dag.create_dagrun(
run_id="test",
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
state=State.RUNNING,
start_date=DEFAULT_DATE,
)
ti = TaskInstance(task=task, run_id="test")
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
runner = StandardTaskRunner(job1)
runner.start()

# Wait until process makes itself the leader of its own process group
with timeout(seconds=1):
while True:
runner_pgid = os.getpgid(runner.process.pid)
if runner_pgid == runner.process.pid:
break
time.sleep(0.01)

# Wait till process finishes
assert runner.return_code(timeout=10) is not None
with open(path_listener_writer) as f:
assert f.readline() == "on_starting\n"
assert f.readline() == "on_task_instance_running\n"
assert f.readline() == "on_task_instance_failed\n"
assert f.readline() == "before_stopping\n"

@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
Expand Down

0 comments on commit 624520d

Please sign in to comment.