From 69865d61efae325afc987538d440d4863a45bd59 Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Fri, 31 May 2024 14:47:28 +0200 Subject: [PATCH] openlineage: execute extraction and message sending in separate process Signed-off-by: Maciej Obuchowski --- .../google/cloud/openlineage/utils.py | 4 + airflow/providers/openlineage/conf.py | 17 +- .../providers/openlineage/plugins/listener.py | 46 ++++- airflow/providers/openlineage/provider.yaml | 11 +- airflow/providers/openlineage/sqlparser.py | 16 +- airflow/providers/openlineage/utils/sql.py | 6 + .../providers/snowflake/hooks/snowflake.py | 6 +- generated/provider_dependencies.json | 4 +- tests/dags/test_openlineage_execution.py | 60 ++++++ .../openlineage/plugins/test_adapter.py | 42 ---- .../openlineage/plugins/test_execution.py | 195 ++++++++++++++++++ .../openlineage/plugins/test_listener.py | 115 +++++------ .../openlineage/plugins/test_openlineage.py | 9 - tests/providers/openlineage/test_conf.py | 25 --- 14 files changed, 407 insertions(+), 149 deletions(-) create mode 100644 tests/dags/test_openlineage_execution.py create mode 100644 tests/providers/openlineage/plugins/test_execution.py diff --git a/airflow/providers/google/cloud/openlineage/utils.py b/airflow/providers/google/cloud/openlineage/utils.py index 4dc0e4030bd9d9..d7034083edc467 100644 --- a/airflow/providers/google/cloud/openlineage/utils.py +++ b/airflow/providers/google/cloud/openlineage/utils.py @@ -158,9 +158,13 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None: if not result: return None """ + # chain.pop modifies passed list, this can be unexpected + chain = chain.copy() chain.reverse() try: while chain: + while isinstance(source, list) and len(source) == 1: + source = source[0] next_key = chain.pop() if isinstance(source, dict): source = source.get(next_key) diff --git a/airflow/providers/openlineage/conf.py b/airflow/providers/openlineage/conf.py index a9601a416bebf3..76fbd70dd1c136 100644 --- 
a/airflow/providers/openlineage/conf.py +++ b/airflow/providers/openlineage/conf.py @@ -33,7 +33,15 @@ import os from typing import Any -from airflow.compat.functools import cache +# Disable caching if we're inside tests - this makes config easier to mock. +if os.getenv("PYTEST_VERSION"): + + def decorator(func): + return func + + cache = decorator +else: + from airflow.compat.functools import cache from airflow.configuration import conf _CONFIG_SECTION = "openlineage" @@ -130,3 +138,10 @@ def dag_state_change_process_pool_size() -> int: """[openlineage] dag_state_change_process_pool_size.""" option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="") return _safe_int_convert(str(option).strip(), default=1) + + +@cache +def execution_timeout() -> int: + """[openlineage] execution_timeout.""" + option = conf.get(_CONFIG_SECTION, "execution_timeout", fallback="") + return _safe_int_convert(str(option).strip(), default=10) diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index 728159a79524fc..6ae850ee9eab80 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -17,12 +17,15 @@ from __future__ import annotations import logging +import os from concurrent.futures import ProcessPoolExecutor from datetime import datetime from typing import TYPE_CHECKING +import psutil from openlineage.client.serde import Serde from packaging.version import Version +from setproctitle import getproctitle, setproctitle from airflow import __version__ as AIRFLOW_VERSION, settings from airflow.listeners import hookimpl @@ -38,6 +41,7 @@ is_selective_lineage_enabled, print_warning, ) +from airflow.settings import configure_orm from airflow.stats import Stats from airflow.utils.timeout import timeout @@ -156,7 +160,7 @@ def on_running(): len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_running() + self._execute(on_running, 
"on_running", use_fork=True) @hookimpl def on_task_instance_success( @@ -223,7 +227,7 @@ def on_success(): len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_success() + self._execute(on_success, "on_success", use_fork=True) if _IS_AIRFLOW_2_10_OR_HIGHER: @@ -318,10 +322,46 @@ def on_failure(): len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_failure() + self._execute(on_failure, "on_failure", use_fork=True) + + def _execute(self, callable, callable_name: str, use_fork: bool = False): + if use_fork: + self._fork_execute(callable, callable_name) + else: + callable() + + def _fork_execute(self, callable, callable_name: str): + self.log.debug("Will fork to execute OpenLineage process.") + pid = os.fork() + if pid: + process = psutil.Process(pid) + try: + self.log.debug("Waiting for process %s", pid) + process.wait(conf.execution_timeout()) + except psutil.TimeoutExpired: + self.log.warning( + "OpenLineage process %s expired. This should not affect process execution.", pid + ) + process.kill() + except BaseException: + # Kill the process directly. + pass + try: + process.kill() + except Exception: + pass + self.log.warning("Process with pid %s finished - parent", pid) + else: + setproctitle(getproctitle() + " - OpenLineage - " + callable_name) + configure_orm(disable_connection_pool=True) + self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid()) + callable() + self.log.debug("Process with current pid finishes after %s", callable_name) + os._exit(0) @property def executor(self) -> ProcessPoolExecutor: + # Executor for dag_run listener def initializer(): # Re-configure the ORM engine as there are issues with multiple processes # if process calls Airflow DB. 
diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml index de17e15c9f04e5..f63596cdc74ca6 100644 --- a/airflow/providers/openlineage/provider.yaml +++ b/airflow/providers/openlineage/provider.yaml @@ -45,8 +45,8 @@ dependencies: - apache-airflow>=2.7.0 - apache-airflow-providers-common-sql>=1.6.0 - attrs>=22.2 - - openlineage-integration-common>=1.15.0 - - openlineage-python>=1.15.0 + - openlineage-integration-common>=1.16.0 + - openlineage-python>=1.16.0 integrations: - integration-name: OpenLineage @@ -144,3 +144,10 @@ config: example: ~ type: integer version_added: 1.8.0 + execution_timeout: + description: | + Maximum amount of time (in seconds) that OpenLineage can spend executing metadata extraction. + default: "10" + example: ~ + type: integer + version_added: 1.9.0 diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py index f181ff8ccea019..470b93d3cb9e05 100644 --- a/airflow/providers/openlineage/sqlparser.py +++ b/airflow/providers/openlineage/sqlparser.py @@ -39,6 +39,7 @@ get_table_schemas, ) from airflow.typing_compat import TypedDict +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -116,7 +117,7 @@ def from_table_meta( return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper()) -class SQLParser: +class SQLParser(LoggingMixin): """Interface for openlineage-sql. 
:param dialect: dialect specific to the database @@ -124,11 +125,18 @@ class SQLParser: """ def __init__(self, dialect: str | None = None, default_schema: str | None = None) -> None: + super().__init__() self.dialect = dialect self.default_schema = default_schema def parse(self, sql: list[str] | str) -> SqlMeta | None: """Parse a single or a list of SQL statements.""" + self.log.debug( + "OpenLineage calling SQL parser with SQL %s dialect %s schema %s", + sql, + self.dialect, + self.default_schema, + ) return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema) def parse_table_schemas( @@ -151,6 +159,7 @@ def parse_table_schemas( "database": database or database_info.database, "use_flat_cross_db_query": database_info.use_flat_cross_db_query, } + self.log.info("PRE getting schemas for input and output tables") return get_table_schemas( hook, namespace, @@ -335,9 +344,8 @@ def split_statement(sql: str) -> list[str]: return split_statement(sql) return [obj for stmt in sql for obj in cls.split_sql_string(stmt) if obj != ""] - @classmethod def create_information_schema_query( - cls, + self, tables: list[DbTableMeta], normalize_name: Callable[[str], str], is_cross_db: bool, @@ -349,7 +357,7 @@ def create_information_schema_query( sqlalchemy_engine: Engine | None = None, ) -> str: """Create SELECT statement to query information schema table.""" - tables_hierarchy = cls._get_tables_hierarchy( + tables_hierarchy = self._get_tables_hierarchy( tables, normalize_name=normalize_name, database=database, diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py index f959745b9361b5..f5d083b4e46905 100644 --- a/airflow/providers/openlineage/utils/sql.py +++ b/airflow/providers/openlineage/utils/sql.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import logging from collections import defaultdict from contextlib import closing from enum import IntEnum @@ -33,6 +34,9 @@ from airflow.hooks.base import BaseHook +log = logging.getLogger(__name__) + + class ColumnIndex(IntEnum): """Enumerates the indices of columns in information schema view.""" @@ -90,6 +94,7 @@ def get_table_schemas( if not in_query and not out_query: return [], [] + log.debug("Starting to query database for table schemas") with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor: if in_query: cursor.execute(in_query) @@ -101,6 +106,7 @@ def get_table_schemas( out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)] else: out_datasets = [] + log.debug("Got table schema query result from database.") return in_datasets, out_datasets diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 978bcf75e1c566..39e17be7b8f859 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -473,10 +473,10 @@ def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser - connection = self.get_connection(getattr(self, self.conn_name_attr)) - namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) - if self.query_ids: + self.log.info("Getting connector to get database info :sadge:") + connection = self.get_connection(getattr(self, self.conn_name_attr)) + namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) return OperatorLineage( run_facets={ "externalQuery": ExternalQueryRunFacet( diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 0f9b6d03c78b11..00f0b556956f50 100644 --- 
a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -913,8 +913,8 @@ "apache-airflow-providers-common-sql>=1.6.0", "apache-airflow>=2.7.0", "attrs>=22.2", - "openlineage-integration-common>=1.15.0", - "openlineage-python>=1.15.0" + "openlineage-integration-common>=1.16.0", + "openlineage-python>=1.16.0" ], "devel-deps": [], "plugins": [ diff --git a/tests/dags/test_openlineage_execution.py b/tests/dags/test_openlineage_execution.py new file mode 100644 index 00000000000000..29fb65cf754579 --- /dev/null +++ b/tests/dags/test_openlineage_execution.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import datetime +import time + +from openlineage.client.generated.base import Dataset + +from airflow.models.dag import DAG +from airflow.models.operator import BaseOperator +from airflow.providers.openlineage.extractors import OperatorLineage + + +class OpenLineageExecutionOperator(BaseOperator): + def __init__(self, *, stall_amount=0, **kwargs) -> None: + super().__init__(**kwargs) + self.stall_amount = stall_amount + + def execute(self, context): + self.log.error("STALL AMOUNT %s", self.stall_amount) + time.sleep(1) + + def get_openlineage_facets_on_start(self): + return OperatorLineage(inputs=[Dataset(namespace="test", name="on-start")]) + + def get_openlineage_facets_on_complete(self, task_instance): + self.log.error("STALL AMOUNT %s", self.stall_amount) + time.sleep(self.stall_amount) + return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")]) + + +with DAG( + dag_id="test_openlineage_execution", + default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)}, + schedule="0 0 * * *", + dagrun_timeout=datetime.timedelta(minutes=60), +): + no_stall = OpenLineageExecutionOperator(task_id="execute_no_stall") + + short_stall = OpenLineageExecutionOperator(task_id="execute_short_stall", stall_amount=5) + + mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15) + + long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30) diff --git a/tests/providers/openlineage/plugins/test_adapter.py b/tests/providers/openlineage/plugins/test_adapter.py index 0212f1402c340e..1242bb3ef98b0a 100644 --- a/tests/providers/openlineage/plugins/test_adapter.py +++ b/tests/providers/openlineage/plugins/test_adapter.py @@ -44,13 +44,7 @@ from airflow.operators.bash import BashOperator from airflow.operators.empty import EmptyOperator from airflow.providers.openlineage.conf import ( - config_path, - custom_extractors, - disabled_operators, 
- is_disabled, - is_source_enabled, namespace, - transport, ) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.plugins.adapter import _PRODUCER, OpenLineageAdapter @@ -64,27 +58,6 @@ pytestmark = pytest.mark.db_test -@pytest.fixture(autouse=True) -def clear_cache(): - config_path.cache_clear() - is_source_enabled.cache_clear() - disabled_operators.cache_clear() - custom_extractors.cache_clear() - namespace.cache_clear() - transport.cache_clear() - is_disabled.cache_clear() - try: - yield - finally: - config_path.cache_clear() - is_source_enabled.cache_clear() - disabled_operators.cache_clear() - custom_extractors.cache_clear() - namespace.cache_clear() - transport.cache_clear() - is_disabled.cache_clear() - - @patch.dict( os.environ, {"OPENLINEAGE_URL": "http://ol-api:5000", "OPENLINEAGE_API_KEY": "api-key"}, @@ -155,9 +128,6 @@ def test_create_client_overrides_env_vars(): assert client.transport.kind == "http" assert client.transport.url == "http://localhost:5050" - transport.cache_clear() - config_path.cache_clear() - with conf_vars({("openlineage", "transport"): '{"type": "console"}'}): client = OpenLineageAdapter().get_or_create_openlineage_client() @@ -893,9 +863,6 @@ def test_configuration_precedence_when_creating_ol_client(): assert client.transport.config.endpoint == "api/v1/lineage" assert client.transport.config.auth.api_key == "random_token" - config_path.cache_clear() - transport.cache_clear() - # Second, check transport in Airflow configuration (airflow.cfg or env variable) with patch.dict( os.environ, @@ -917,9 +884,6 @@ def test_configuration_precedence_when_creating_ol_client(): assert client.transport.kafka_config.topic == "test" assert client.transport.kafka_config.config == {"acks": "all"} - config_path.cache_clear() - transport.cache_clear() - # Third, check legacy OPENLINEAGE_CONFIG env variable with patch.dict( os.environ, @@ -942,9 +906,6 @@ def 
test_configuration_precedence_when_creating_ol_client(): assert client.transport.config.endpoint == "api/v1/lineage" assert client.transport.config.auth.api_key == "random_token" - config_path.cache_clear() - transport.cache_clear() - # Fourth, check legacy OPENLINEAGE_URL env variable with patch.dict( os.environ, @@ -967,9 +928,6 @@ def test_configuration_precedence_when_creating_ol_client(): assert client.transport.config.endpoint == "api/v1/lineage" assert client.transport.config.auth.api_key == "test_api_key" - config_path.cache_clear() - transport.cache_clear() - # If all else fails, use console transport with patch.dict(os.environ, {}, clear=True): with conf_vars( diff --git a/tests/providers/openlineage/plugins/test_execution.py b/tests/providers/openlineage/plugins/test_execution.py new file mode 100644 index 00000000000000..f9392afc2a409c --- /dev/null +++ b/tests/providers/openlineage/plugins/test_execution.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +import logging +import os +import random +import shutil +import tempfile +from pathlib import Path +from unittest import mock + +import pytest + +from airflow.jobs.job import Job +from airflow.jobs.local_task_job_runner import LocalTaskJobRunner +from airflow.listeners.listener import get_listener_manager +from airflow.models import DagBag, TaskInstance +from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain +from airflow.providers.openlineage.plugins.listener import OpenLineageListener +from airflow.task.task_runner.standard_task_runner import StandardTaskRunner +from airflow.utils import timezone +from airflow.utils.state import State +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS +from tests.test_utils.config import conf_vars + +TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] +DEFAULT_DATE = timezone.datetime(2016, 1, 1) + + +log = logging.getLogger(__name__) + + +def read_file_content(file_path: str) -> str: + with open(file_path) as file: + return file.read() + + +def get_sorted_events(event_dir: str) -> list[str]: + event_paths = [os.path.join(event_dir, event_path) for event_path in sorted(os.listdir(event_dir))] + return [json.loads(read_file_content(event_path)) for event_path in event_paths] + + +def has_value_in_events(events, chain, value): + x = [get_from_nullable_chain(event, chain) for event in events] + log.error(x) + y = [z == value for z in x] + return any(y) + + +with tempfile.TemporaryDirectory(prefix="venv") as tmp_dir: + listener_path = Path(tmp_dir) / "event" + + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+") + @pytest.mark.usefixtures("reset_logging_config") + class TestOpenLineageExecution: + @pytest.fixture(autouse=True) + def clean_listener_manager(self): + get_listener_manager().clear() + yield + get_listener_manager().clear() + + def setup_job(self, task_name, run_id): + dirpath = Path(tmp_dir) + if 
dirpath.exists(): + shutil.rmtree(dirpath) + dirpath.mkdir(exist_ok=True, parents=True) + lm = get_listener_manager() + lm.add_listener(OpenLineageListener()) + + dagbag = DagBag( + dag_folder=TEST_DAG_FOLDER, + include_examples=False, + ) + dag = dagbag.dags.get("test_openlineage_execution") + task = dag.get_task(task_name) + + dag.create_dagrun( + run_id=run_id, + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + ti = TaskInstance(task=task, run_id=run_id) + job = Job(id=random.randint(0, 23478197), dag_id=ti.dag_id) + job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + task_runner = StandardTaskRunner(job_runner) + with mock.patch("airflow.task.task_runner.get_task_runner", return_value=task_runner): + job_runner._execute() + + return task_runner.return_code(timeout=60) + + @pytest.mark.db_test + @conf_vars({("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}'}) + def test_not_stalled_task_emits_proper_lineage(self): + task_name = "execute_no_stall" + run_id = "test1" + self.setup_job(task_name, run_id) + + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert has_value_in_events(events, ["inputs", "name"], "on-complete") + + @conf_vars( + { + ("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}', + ("openlineage", "execution_timeout"): "15", + } + ) + @pytest.mark.db_test + def test_short_stalled_task_emits_proper_lineage(self): + self.setup_job("execute_short_stall", "test_short_stalled_task_emits_proper_lineage") + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert has_value_in_events(events, ["inputs", "name"], "on-complete") + + @conf_vars( + { + ("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}', + ("openlineage", "execution_timeout"): "3", + } + ) + 
@pytest.mark.db_test + def test_short_stalled_task_extraction_with_low_execution_is_killed_by_ol_timeout(self): + self.setup_job( + "execute_short_stall", + "test_short_stalled_task_extraction_with_low_execution_is_killed_by_ol_timeout", + ) + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert not has_value_in_events(events, ["inputs", "name"], "on-complete") + + @conf_vars({("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}'}) + @pytest.mark.db_test + def test_mid_stalled_task_is_killed_by_ol_timeout(self): + self.setup_job("execute_mid_stall", "test_mid_stalled_task_is_killed_by_openlineage") + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert not has_value_in_events(events, ["inputs", "name"], "on-complete") + + @conf_vars( + { + ("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}', + ("openlineage", "execution_timeout"): "60", + ("core", "task_success_overtime"): "3", + } + ) + @pytest.mark.db_test + def test_long_stalled_task_is_killed_by_listener_overtime_if_ol_timeout_long_enough(self): + dirpath = Path(tmp_dir) + if dirpath.exists(): + shutil.rmtree(dirpath) + dirpath.mkdir(exist_ok=True, parents=True) + lm = get_listener_manager() + lm.add_listener(OpenLineageListener()) + + dagbag = DagBag( + dag_folder=TEST_DAG_FOLDER, + include_examples=False, + ) + dag = dagbag.dags.get("test_openlineage_execution") + task = dag.get_task("execute_long_stall") + + dag.create_dagrun( + run_id="test_long_stalled_task_is_killed_by_listener_overtime_if_ol_timeout_long_enough", + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + ti = TaskInstance( + task=task, + run_id="test_long_stalled_task_is_killed_by_listener_overtime_if_ol_timeout_long_enough", + ) + job = Job(id="1", dag_id=ti.dag_id) + job_runner = 
LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) + job_runner._execute() + + events = get_sorted_events(tmp_dir) + assert has_value_in_events(events, ["inputs", "name"], "on-start") + assert not has_value_in_events(events, ["inputs", "name"], "on-complete") diff --git a/tests/providers/openlineage/plugins/test_listener.py b/tests/providers/openlineage/plugins/test_listener.py index 572a6877f35e73..2fa8216bb444f1 100644 --- a/tests/providers/openlineage/plugins/test_listener.py +++ b/tests/providers/openlineage/plugins/test_listener.py @@ -29,7 +29,6 @@ from airflow.models import DAG, DagRun, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.operators.python import PythonOperator -from airflow.providers.openlineage import conf from airflow.providers.openlineage.plugins.listener import OpenLineageListener from airflow.providers.openlineage.utils.selective_enable import disable_lineage, enable_lineage from airflow.utils.state import State @@ -62,6 +61,10 @@ def render_df(): return pd.DataFrame({"col": [1, 2]}) +def regular_call(self, callable, callable_name, use_fork): + callable() + + @patch("airflow.models.TaskInstance.xcom_push") @patch("airflow.models.BaseOperator.render_template") def test_listener_does_not_change_task_instance(render_mock, xcom_push_mock): @@ -214,6 +217,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_custom_facets") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_adapter_start_task_is_called_with_proper_arguments( mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet, mock_disabled ): @@ -254,6 +258,7 @@ def test_adapter_start_task_is_called_with_proper_arguments( 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mocked_adapter, mock_disabled): """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. @@ -296,6 +301,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): @mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_adapter_complete_task_is_called_with_proper_arguments( mock_get_job_name, mocked_adapter, mock_disabled ): @@ -335,6 +341,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date): ) +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method(): """Tests the OpenLineageListener's response when a task instance is in the running state. @@ -353,6 +360,7 @@ def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_met @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): """Tests the OpenLineageListener's response when a task instance is in the failed state. 
@@ -375,6 +383,7 @@ def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_meth @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call) def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(mock_adapter): """Tests the OpenLineageListener's response when a task instance is in the success state. @@ -524,13 +533,10 @@ def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_exec """mock ProcessPoolExecutor and check if conf.dag_state_change_process_pool_size is applied to max_workers""" listener = OpenLineageListener() # mock ProcessPoolExecutor class - try: - with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}): - listener.on_dag_run_running(mock.MagicMock(), None) - mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY) - mock_executor.return_value.submit.assert_called_once() - finally: - conf.dag_state_change_process_pool_size.cache_clear() + with conf_vars({("openlineage", "dag_state_change_process_pool_size"): max_workers}): + listener.on_dag_run_running(mock.MagicMock(), None) + mock_executor.assert_called_once_with(max_workers=expected, initializer=mock.ANY) + mock_executor.return_value.submit.assert_called_once() class TestOpenLineageSelectiveEnable: @@ -570,7 +576,6 @@ def test_listener_with_dag_enabled(self, selective_enable, enable_dag, expected_ if enable_dag: enable_lineage(self.dag) - conf.selective_enable.cache_clear() with conf_vars({("openlineage", "selective_enable"): selective_enable}): listener = OpenLineageListener() listener._executor = mock.Mock() @@ -580,11 +585,6 @@ def test_listener_with_dag_enabled(self, selective_enable, enable_dag, expected_ listener.on_dag_run_failed(self.dagrun, msg="test failure") listener.on_dag_run_success(self.dagrun, msg="test success") - try: - assert 
expected_call_count == listener._executor.submit.call_count - finally: - conf.selective_enable.cache_clear() - @pytest.mark.parametrize( "selective_enable, enable_task, expected_dag_call_count, expected_task_call_count", [ @@ -594,6 +594,9 @@ def test_listener_with_dag_enabled(self, selective_enable, enable_dag, expected_ ("False", False, 3, 3), ], ) + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) def test_listener_with_task_enabled( self, selective_enable, enable_task, expected_dag_call_count, expected_task_call_count ): @@ -604,49 +607,46 @@ def test_listener_with_task_enabled( on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} - conf.selective_enable.cache_clear() with conf_vars({("openlineage", "selective_enable"): selective_enable}): listener = OpenLineageListener() listener._executor = mock.Mock() listener.extractor_manager = mock.Mock() listener.adapter = mock.Mock() - try: - # run all three DagRun-related hooks - listener.on_dag_run_running(self.dagrun, msg="test running") - listener.on_dag_run_failed(self.dagrun, msg="test failure") - listener.on_dag_run_success(self.dagrun, msg="test success") - - assert expected_dag_call_count == listener._executor.submit.call_count - - # run TaskInstance-related hooks for lineage enabled task - listener.on_task_instance_running(None, self.task_instance_1, None) - listener.on_task_instance_success(None, self.task_instance_1, None) - listener.on_task_instance_failed( - previous_state=None, - task_instance=self.task_instance_1, - session=None, - **on_task_failed_kwargs, - ) - - assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count - - # run TaskInstance-related hooks for lineage disabled task - listener.on_task_instance_running(None, self.task_instance_2, None) - listener.on_task_instance_success(None, self.task_instance_2, None) - listener.on_task_instance_failed( - 
previous_state=None, - task_instance=self.task_instance_2, - session=None, - **on_task_failed_kwargs, - ) - - # with selective-enable disabled both task_1 and task_2 should trigger metadata extraction - if selective_enable == "False": - expected_task_call_count *= 2 - - assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count - finally: - conf.selective_enable.cache_clear() + + # run all three DagRun-related hooks + listener.on_dag_run_running(self.dagrun, msg="test running") + listener.on_dag_run_failed(self.dagrun, msg="test failure") + listener.on_dag_run_success(self.dagrun, msg="test success") + + assert expected_dag_call_count == listener._executor.submit.call_count + + # run TaskInstance-related hooks for lineage enabled task + listener.on_task_instance_running(None, self.task_instance_1, None) + listener.on_task_instance_success(None, self.task_instance_1, None) + listener.on_task_instance_failed( + previous_state=None, + task_instance=self.task_instance_1, + session=None, + **on_task_failed_kwargs, + ) + + assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count + + # run TaskInstance-related hooks for lineage disabled task + listener.on_task_instance_running(None, self.task_instance_2, None) + listener.on_task_instance_success(None, self.task_instance_2, None) + listener.on_task_instance_failed( + previous_state=None, + task_instance=self.task_instance_2, + session=None, + **on_task_failed_kwargs, + ) + + # with selective-enable disabled both task_1 and task_2 should trigger metadata extraction + if selective_enable == "False": + expected_task_call_count *= 2 + + assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count @pytest.mark.parametrize( "selective_enable, enable_task, expected_call_count, expected_task_call_count", @@ -657,6 +657,9 @@ def test_listener_with_task_enabled( ("False", False, 3, 3), ], ) + @mock.patch( + 
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call + ) def test_listener_with_dag_disabled_task_enabled( self, selective_enable, enable_task, expected_call_count, expected_task_call_count ): @@ -668,7 +671,6 @@ def test_listener_with_dag_disabled_task_enabled( on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} - conf.selective_enable.cache_clear() with conf_vars({("openlineage", "selective_enable"): selective_enable}): listener = OpenLineageListener() listener._executor = mock.Mock() @@ -687,8 +689,5 @@ def test_listener_with_dag_disabled_task_enabled( previous_state=None, task_instance=self.task_instance_1, session=None, **on_task_failed_kwargs ) - try: - assert expected_call_count == listener._executor.submit.call_count - assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count - finally: - conf.selective_enable.cache_clear() + assert expected_call_count == listener._executor.submit.call_count + assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count diff --git a/tests/providers/openlineage/plugins/test_openlineage.py b/tests/providers/openlineage/plugins/test_openlineage.py index e736b3ee9e57a4..dcb8198ceccad8 100644 --- a/tests/providers/openlineage/plugins/test_openlineage.py +++ b/tests/providers/openlineage/plugins/test_openlineage.py @@ -23,7 +23,6 @@ import pytest -from airflow.providers.openlineage.conf import config_path, is_disabled, transport from tests.conftest import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES from tests.test_utils.config import conf_vars @@ -33,19 +32,11 @@ ) class TestOpenLineageProviderPlugin: def setup_method(self): - is_disabled.cache_clear() - transport.cache_clear() - config_path.cache_clear() # Remove module under test if loaded already before. This lets us # import the same source files for more than one test. 
if "airflow.providers.openlineage.plugins.openlineage" in sys.modules: del sys.modules["airflow.providers.openlineage.plugins.openlineage"] - def teardown_method(self): - is_disabled.cache_clear() - transport.cache_clear() - config_path.cache_clear() - @pytest.mark.parametrize( "mocks, expected", [ diff --git a/tests/providers/openlineage/test_conf.py b/tests/providers/openlineage/test_conf.py index f52d8453acc2c9..60060b001c6d62 100644 --- a/tests/providers/openlineage/test_conf.py +++ b/tests/providers/openlineage/test_conf.py @@ -69,31 +69,6 @@ ) -@pytest.fixture(autouse=True) -def clear_cache(): - config_path.cache_clear() - is_source_enabled.cache_clear() - disabled_operators.cache_clear() - custom_extractors.cache_clear() - namespace.cache_clear() - transport.cache_clear() - is_disabled.cache_clear() - selective_enable.cache_clear() - dag_state_change_process_pool_size.cache_clear() - try: - yield - finally: - config_path.cache_clear() - is_source_enabled.cache_clear() - disabled_operators.cache_clear() - custom_extractors.cache_clear() - namespace.cache_clear() - transport.cache_clear() - is_disabled.cache_clear() - selective_enable.cache_clear() - dag_state_change_process_pool_size.cache_clear() - - @pytest.mark.parametrize( ("var_string", "expected"), (