diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 1fe60bd773733..da175866a03ef 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -29,6 +29,7 @@ from openlineage.client.facet import BaseFacet as BaseFacet_V1 from openlineage.client.facet_v2 import JobFacet, RunFacet +from airflow.providers.openlineage.utils.utils import IS_AIRFLOW_2_10_OR_HIGHER from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState @@ -115,7 +116,14 @@ def _execute_extraction(self) -> OperatorLineage | None: return None def extract_on_complete(self, task_instance) -> OperatorLineage | None: - if task_instance.state == TaskInstanceState.FAILED: + failed_states = [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY] + if not IS_AIRFLOW_2_10_OR_HIGHER: # todo: remove when min airflow version >= 2.10.0 + # Before fix (#41053) implemented in Airflow 2.10 TaskInstance's state was still RUNNING when + # being passed to listener's on_failure method. Since `extract_on_complete()` is only called + # after task completion, RUNNING state means that we are dealing with FAILED task in < 2.10 + failed_states = [TaskInstanceState.RUNNING] + + if task_instance.state in failed_states: on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) if on_failed and callable(on_failed): self.log.debug( diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 2f227fa2e689e..9da9267db3b28 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -23,15 +23,15 @@ import psutil from openlineage.client.serde import Serde -from packaging.version import Version from setproctitle import getproctitle, setproctitle -from airflow import __version__ as AIRFLOW_VERSION, settings +from airflow import settings from airflow.listeners import hookimpl from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import ExtractorManager from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState from airflow.providers.openlineage.utils.utils import ( + IS_AIRFLOW_2_10_OR_HIGHER, get_airflow_job_facet, get_airflow_mapped_task_facet, get_airflow_run_facet, @@ -53,12 +53,11 @@ from airflow.models import DagRun, TaskInstance _openlineage_listener: OpenLineageListener | None = None -_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") def _get_try_number_success(val): # todo: remove when min airflow version >= 2.10.0 - if _IS_AIRFLOW_2_10_OR_HIGHER: + if IS_AIRFLOW_2_10_OR_HIGHER: return val.try_number return val.try_number - 1 @@ -247,7 +246,7 @@ def on_success(): self._execute(on_success, "on_success", use_fork=True) - if _IS_AIRFLOW_2_10_OR_HIGHER: + if IS_AIRFLOW_2_10_OR_HIGHER: @hookimpl def on_task_instance_failed( diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index 17eb522a6d1ba..3c5ede2c7cb8c 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -64,7 +64,7 @@ log = logging.getLogger(__name__) _NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" -_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") +IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") def try_import_from_string(string: str) -> Any: @@ -663,7 +663,7 @@ def normalize_sql(sql: str | Iterable[str]): def should_use_external_connection(hook) -> bool: # If we're at Airflow 2.10, the execution is process-isolated, so we can safely run those again. - if not _IS_AIRFLOW_2_10_OR_HIGHER: + if not IS_AIRFLOW_2_10_OR_HIGHER: return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"] return True diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index d813b7a7e12e3..d415f19ed4efd 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -82,7 +82,7 @@ class TestRedshiftSQLOpenLineage: "airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock, ) - @patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock) + @patch("airflow.providers.openlineage.utils.utils.IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock) @patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn") def test_execute_openlineage_events( self, diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py index 2f847140070a9..ee242d6dfc3aa 100644 --- a/tests/providers/openlineage/extractors/test_base.py +++ b/tests/providers/openlineage/extractors/test_base.py @@ -25,6 +25,7 @@ from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job from airflow.models.baseoperator import BaseOperator +from airflow.models.taskinstance import TaskInstanceState from airflow.operators.python import PythonOperator from airflow.providers.openlineage.extractors.base import ( BaseExtractor, @@ -233,6 +234,47 @@ def test_extraction_without_on_start(): ) +@pytest.mark.parametrize( + "task_state, is_airflow_2_10_or_higher, should_call_on_failure", + ( + # Airflow >= 2.10 + (TaskInstanceState.FAILED, True, True), + (TaskInstanceState.UP_FOR_RETRY, True, True), + (TaskInstanceState.RUNNING, True, False), + (TaskInstanceState.SUCCESS, True, False), + # Airflow < 2.10 + (TaskInstanceState.RUNNING, False, True), + (TaskInstanceState.SUCCESS, False, False), + (TaskInstanceState.FAILED, False, False), # should never happen, fixed in #41053 + (TaskInstanceState.UP_FOR_RETRY, False, False), # should never happen, fixed in #41053 + ), +) +def test_extract_on_failure(task_state, is_airflow_2_10_or_higher, should_call_on_failure): + task_instance = mock.Mock(state=task_state) + operator = mock.Mock() + operator.get_openlineage_facets_on_failure = mock.Mock( + return_value=OperatorLineage(run_facets={"failed": True}) + ) + operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None) + + extractor = DefaultExtractor(operator=operator) + + with mock.patch( + "airflow.providers.openlineage.extractors.base.IS_AIRFLOW_2_10_OR_HIGHER", is_airflow_2_10_or_higher + ): + result = extractor.extract_on_complete(task_instance) + + if should_call_on_failure: + operator.get_openlineage_facets_on_failure.assert_called_once_with(task_instance) + operator.get_openlineage_facets_on_complete.assert_not_called() + assert isinstance(result, OperatorLineage) + assert result.run_facets == {"failed": True} + else: + operator.get_openlineage_facets_on_failure.assert_not_called() + operator.get_openlineage_facets_on_complete.assert_called_once_with(task_instance) + assert result is None + + @mock.patch("airflow.providers.openlineage.conf.custom_extractors") def test_extractors_env_var(custom_extractors): custom_extractors.return_value = {"tests.providers.openlineage.extractors.test_base.ExampleExtractor"}