Skip to content

Commit

Permalink
make OpenLineage provider support Airflow 3's listener interface
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski committed Jan 8, 2025
1 parent c35dedf commit 0b6cd12
Show file tree
Hide file tree
Showing 7 changed files with 1,490 additions and 672 deletions.
20 changes: 11 additions & 9 deletions docs/apache-airflow/administration-and-deployment/listeners.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ For example if you want to implement a listener that uses the ``error`` field in
...
@hookimpl
def on_task_instance_failed(
self, previous_state, task_instance, error: None | str | BaseException, session
):
def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
# Handle error case here
pass
Expand All @@ -177,15 +175,19 @@ For example if you want to implement a listener that uses the ``error`` field in
...
@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, session):
def on_task_instance_failed(self, previous_state, task_instance):
# Handle no error case here
pass
List of changes in the listener interfaces since 2.8.0 when they were introduced:


+-----------------+-----------------------------+---------------------------------------+
| Airflow Version | Affected method | Change |
+=================+=============================+=======================================+
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
+-----------------+-----------------------------+---------------------------------------+
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
| Airflow Version | Affected method | Change |
+=================+============================================+=========================================================================+
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
| 3.0.0 | ``on_task_instance_running``, | ``session`` argument removed from task instance listeners, |
| | ``on_task_instance_success``, | ``task_instance`` object is now an instance of ``RuntimeTaskInstance`` |
| | ``on_task_instance_failed`` | |
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
224 changes: 138 additions & 86 deletions providers/src/airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING

import psutil
Expand All @@ -33,6 +34,7 @@
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
from airflow.providers.openlineage.utils.utils import (
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
get_airflow_dag_run_facet,
get_airflow_debug_facet,
get_airflow_job_facet,
Expand All @@ -42,7 +44,6 @@
get_user_provided_run_facets,
is_operator_disabled,
is_selective_lineage_enabled,
is_ti_rescheduled_already,
print_warning,
)
from airflow.settings import configure_orm
Expand All @@ -52,9 +53,9 @@
from airflow.utils.timeout import timeout

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.settings import Session

_openlineage_listener: OpenLineageListener | None = None

Expand Down Expand Up @@ -87,28 +88,58 @@ def __init__(self):
self.extractor_manager = ExtractorManager()
self.adapter = OpenLineageAdapter()

@hookimpl
def on_task_instance_running(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session, # This will always be QUEUED
) -> None:
if not getattr(task_instance, "task", None) is not None:
self.log.warning(
"No task set for TI object task_id: %s - dag_id: %s - run_id %s",
task_instance.task_id,
task_instance.dag_id,
task_instance.run_id,
)
return
if AIRFLOW_V_3_0_PLUS:

self.log.debug("OpenLineage listener got notification about task instance start")
dagrun = task_instance.dag_run
task = task_instance.task
if TYPE_CHECKING:
assert task
dag = task.dag
@hookimpl
def on_task_instance_running(
self,
previous_state: TaskInstanceState,
task_instance: RuntimeTaskInstance,
):
if not getattr(task_instance, "task", None) is not None:
self.log.warning(
"No task set for TI object task_id: %s - dag_id: %s - run_id %s",
task_instance.task_id,
task_instance.dag_id,
task_instance.run_id,
)
return

self.log.debug("OpenLineage listener got notification about task instance start")
context = task_instance.get_template_context()

task = context["task"]
if TYPE_CHECKING:
assert task
dagrun = context["dag_run"]
dag = context["dag"]
self._on_task_instance_running(task_instance, dag, dagrun, task)
else:

@hookimpl
def on_task_instance_running(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session, # type: ignore[valid-type]
) -> None:
if not getattr(task_instance, "task", None) is not None:
self.log.warning(
"No task set for TI object task_id: %s - dag_id: %s - run_id %s",
task_instance.task_id,
task_instance.dag_id,
task_instance.run_id,
)
return

self.log.debug("OpenLineage listener got notification about task instance start")
task = task_instance.task
if TYPE_CHECKING:
assert task

self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task)

def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task):
if is_operator_disabled(task):
self.log.debug(
"Skipping OpenLineage event emission for operator `%s` "
Expand All @@ -127,35 +158,34 @@ def on_task_instance_running(
return

# Needs to be calculated outside of inner method so that it gets cached for usage in fork processes
data_interval_start = dagrun.data_interval_start
if isinstance(data_interval_start, datetime):
data_interval_start = data_interval_start.isoformat()
data_interval_end = dagrun.data_interval_end
if isinstance(data_interval_end, datetime):
data_interval_end = data_interval_end.isoformat()

debug_facet = get_airflow_debug_facet()

@print_warning(self.log)
def on_running():
# that's a workaround to detect task running from deferred state
# we return here because Airflow 2.3 needs task from deferred state
if task_instance.next_method is not None:
return

if is_ti_rescheduled_already(task_instance):
context = task_instance.get_template_context()
if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0:
self.log.debug("Skipping this instance of rescheduled task - START event was emitted already")
return

parent_run_id = self.adapter.build_dag_run_id(
dag_id=dag.dag_id,
logical_date=dagrun.logical_date,
clear_number=dagrun.clear_number,
clear_number=0,
)

if hasattr(task_instance, "logical_date"):
logical_date = task_instance.logical_date
else:
logical_date = task_instance.execution_date
start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow()

task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=task_instance.try_number,
logical_date=logical_date,
logical_date=dagrun.logical_date,
map_index=task_instance.map_index,
)
event_type = RunState.RUNNING.value.lower()
Expand All @@ -164,11 +194,6 @@ def on_running():
with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(dagrun, task)

start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow()
data_interval_start = (
dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None
)
data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None
redacted_event = self.adapter.start_task(
run_id=task_uuid,
job_name=get_job_name(task),
Expand All @@ -195,17 +220,39 @@ def on_running():

self._execute(on_running, "on_running", use_fork=True)

@hookimpl
def on_task_instance_success(
self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
) -> None:
self.log.debug("OpenLineage listener got notification about task instance success")
if AIRFLOW_V_3_0_PLUS:

@hookimpl
def on_task_instance_success(
self, previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance
) -> None:
self.log.debug("OpenLineage listener got notification about task instance success")

context = task_instance.get_template_context()
task = context["task"]
if TYPE_CHECKING:
assert task
dagrun = context["dag_run"]
dag = context["dag"]
self._on_task_instance_success(task_instance, dag, dagrun, task)

dagrun = task_instance.dag_run
task = task_instance.task
if TYPE_CHECKING:
assert task
dag = task.dag
else:

@hookimpl
def on_task_instance_success(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session, # type: ignore[valid-type]
) -> None:
self.log.debug("OpenLineage listener got notification about task instance success")
task = task_instance.task
if TYPE_CHECKING:
assert task
self._on_task_instance_success(task_instance, task.dag, task_instance.dag_run, task)

def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dagrun, task):
end_date = timezone.utcnow()

if is_operator_disabled(task):
self.log.debug(
Expand All @@ -232,15 +279,11 @@ def on_success():
clear_number=dagrun.clear_number,
)

if hasattr(task_instance, "logical_date"):
logical_date = task_instance.logical_date
else:
logical_date = task_instance.execution_date
task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=_get_try_number_success(task_instance),
logical_date=logical_date,
logical_date=dagrun.logical_date,
map_index=task_instance.map_index,
)
event_type = RunState.COMPLETE.value.lower()
Expand All @@ -251,8 +294,6 @@ def on_success():
dagrun, task, complete=True, task_instance=task_instance
)

end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow()

redacted_event = self.adapter.complete_task(
run_id=task_uuid,
job_name=get_job_name(task),
Expand All @@ -273,44 +314,62 @@ def on_success():

self._execute(on_success, "on_success", use_fork=True)

if AIRFLOW_V_2_10_PLUS:
if AIRFLOW_V_3_0_PLUS:

@hookimpl
def on_task_instance_failed(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
error: None | str | BaseException,
session: Session,
) -> None:
self._on_task_instance_failed(
previous_state=previous_state, task_instance=task_instance, error=error, session=session
)
self.log.debug("OpenLineage listener got notification about task instance failure")
context = task_instance.get_template_context()
task = context["task"]
if TYPE_CHECKING:
assert task
dagrun = context["dag_run"]
dag = context["dag"]
self._on_task_instance_failed(task_instance, dag, dagrun, task, error)

elif AIRFLOW_V_2_10_PLUS:

@hookimpl
def on_task_instance_failed(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
error: None | str | BaseException,
session: Session, # type: ignore[valid-type]
) -> None:
self.log.debug("OpenLineage listener got notification about task instance failure")
task = task_instance.task
if TYPE_CHECKING:
assert task
self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task, error)
else:

@hookimpl
def on_task_instance_failed(
self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session, # type: ignore[valid-type]
) -> None:
self._on_task_instance_failed(
previous_state=previous_state, task_instance=task_instance, error=None, session=session
)
task = task_instance.task
if TYPE_CHECKING:
assert task
self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task)

def _on_task_instance_failed(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session,
task_instance: TaskInstance | RuntimeTaskInstance,
dag,
dagrun,
task,
error: None | str | BaseException = None,
) -> None:
self.log.debug("OpenLineage listener got notification about task instance failure")

dagrun = task_instance.dag_run
task = task_instance.task
if TYPE_CHECKING:
assert task
dag = task.dag
end_date = timezone.utcnow()

if is_operator_disabled(task):
self.log.debug(
Expand All @@ -337,16 +396,11 @@ def on_failure():
clear_number=dagrun.clear_number,
)

if hasattr(task_instance, "logical_date"):
logical_date = task_instance.logical_date
else:
logical_date = task_instance.execution_date

task_uuid = self.adapter.build_task_instance_run_id(
dag_id=dag.dag_id,
task_id=task.task_id,
try_number=task_instance.try_number,
logical_date=logical_date,
logical_date=dagrun.logical_date,
map_index=task_instance.map_index,
)
event_type = RunState.FAIL.value.lower()
Expand All @@ -357,8 +411,6 @@ def on_failure():
dagrun, task, complete=True, task_instance=task_instance
)

end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow()

redacted_event = self.adapter.fail_task(
run_id=task_uuid,
job_name=get_job_name(task),
Expand Down
Loading

0 comments on commit 0b6cd12

Please sign in to comment.