openlineage: execute extraction and message sending in separate process
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
mobuchowski committed Jun 14, 2024
1 parent 6f40984 commit 187d87e
Showing 11 changed files with 391 additions and 142 deletions.
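At a high level, each OpenLineage listener hook now forks a child process that performs the extraction and event emission, while the parent waits for it with a configurable timeout and escalates from terminate to kill if the child overruns. A minimal sketch of that fork-and-wait pattern follows; it is illustrative only, not the listener's actual code, and the helper name run_with_timeout is made up:

import os

import psutil


def run_with_timeout(fn, timeout_s: int = 10) -> None:
    # Run fn in a forked child; the parent waits at most timeout_s seconds.
    pid = os.fork()
    if pid:  # parent
        child = psutil.Process(pid)
        try:
            child.wait(timeout_s)
        except psutil.TimeoutExpired:
            child.terminate()
            try:
                child.wait(timeout=3)  # give the child a chance to clean up
            except psutil.TimeoutExpired:
                child.kill()
    else:  # child
        try:
            fn()
        finally:
            os._exit(0)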
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/openlineage/utils.py
@@ -158,9 +158,13 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None:
if not result:
return None
"""
# chain.pop modifies the passed list, which can be unexpected
chain = chain.copy()
chain.reverse()
try:
while chain:
while isinstance(source, list) and len(source) == 1:
source = source[0]
next_key = chain.pop()
if isinstance(source, dict):
source = source.get(next_key)
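For context, a hedged usage sketch of get_from_nullable_chain after this change (the payload below is made up): single-element lists along the chain are unwrapped, and the caller's chain list is no longer mutated because the function copies it first.

from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain

source = {"spec": [{"table": {"name": "orders"}}]}  # hypothetical nested payload
chain = ["spec", "table", "name"]
assert get_from_nullable_chain(source, chain) == "orders"
assert chain == ["spec", "table", "name"]  # the passed list is untouched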
17 changes: 16 additions & 1 deletion airflow/providers/openlineage/conf.py
@@ -33,7 +33,15 @@
import os
from typing import Any

from airflow.compat.functools import cache
# Disable caching if we're inside tests - this makes config easier to mock.
if os.getenv("PYTEST_VERSION"):

def decorator(func):
return func

cache = decorator
else:
from airflow.compat.functools import cache
from airflow.configuration import conf

_CONFIG_SECTION = "openlineage"
@@ -130,3 +138,10 @@ def dag_state_change_process_pool_size() -> int:
"""[openlineage] dag_state_change_process_pool_size."""
option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="")
return _safe_int_convert(str(option).strip(), default=1)


@cache
def execution_timeout() -> int:
"""[openlineage] execution_timeout."""
option = conf.get(_CONFIG_SECTION, "execution_timeout", fallback="")
return _safe_int_convert(str(option).strip(), default=10)
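Because the cache decorator is a no-op under pytest, tests can override the new option and observe the change immediately, without the cache_clear() bookkeeping removed from the adapter tests below. A hedged sketch, assuming the conf_vars helper that those tests already use:

from airflow.providers.openlineage.conf import execution_timeout
from tests.test_utils.config import conf_vars

with conf_vars({("openlineage", "execution_timeout"): "30"}):
    assert execution_timeout() == 30  # re-read on every call, no cache_clear() needed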
51 changes: 48 additions & 3 deletions airflow/providers/openlineage/plugins/listener.py
@@ -17,12 +17,15 @@
from __future__ import annotations

import logging
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING

import psutil
from openlineage.client.serde import Serde
from packaging.version import Version
from setproctitle import getproctitle, setproctitle

from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow.listeners import hookimpl
@@ -38,6 +41,7 @@
is_selective_lineage_enabled,
print_warning,
)
from airflow.settings import configure_orm
from airflow.stats import Stats
from airflow.utils.timeout import timeout

@@ -156,7 +160,7 @@ def on_running():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_running()
self._execute(on_running, "on_running", use_fork=True)

@hookimpl
def on_task_instance_success(
@@ -223,7 +227,7 @@ def on_success():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_success()
self._execute(on_success, "on_success", use_fork=True)

if _IS_AIRFLOW_2_10_OR_HIGHER:

@@ -318,10 +322,51 @@ def on_failure():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_failure()
self._execute(on_failure, "on_failure", use_fork=True)

def _execute(self, callable, callable_name: str, use_fork: bool = False):
if use_fork:
self._fork_execute(callable, callable_name)
else:
callable()

def _terminate_with_wait(self, process: psutil.Process):
process.terminate()
try:
# Wait up to 3 seconds so the process can clean up before being killed.
process.wait(timeout=3)
except psutil.TimeoutExpired:
# If it's not dead by then, force kill it.
process.kill()

def _fork_execute(self, callable, callable_name: str):
self.log.debug("Will fork to execute OpenLineage process.")
pid = os.fork()
if pid:
process = psutil.Process(pid)
try:
self.log.debug("Waiting for process %s", pid)
process.wait(conf.execution_timeout())
except psutil.TimeoutExpired:
self.log.warning(
"OpenLineage process %s expired. This should not affect process execution.", pid
)
self._terminate_with_wait(process)
except BaseException:
# Kill the process directly.
self._terminate_with_wait(process)
self.log.warning("Process with pid %s finished - parent", pid)
else:
setproctitle(getproctitle() + " - OpenLineage - " + callable_name)
configure_orm(disable_connection_pool=True)
self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid())
callable()
self.log.debug("Process with current pid finishes after %s", callable_name)
os._exit(0)

@property
def executor(self) -> ProcessPoolExecutor:
# Executor for dag_run listener
def initializer():
# Re-configure the ORM engine, as there are issues with multiple processes
# if a process calls the Airflow DB.
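The executor property kept for the DagRun hooks still uses a ProcessPoolExecutor, with an initializer that re-configures the ORM so worker processes do not reuse the parent's pooled DB connections. A rough sketch of that pattern; the pool size and helper name are illustrative:

from concurrent.futures import ProcessPoolExecutor

from airflow.settings import configure_orm


def _orm_initializer():
    # Each worker gets its own engine; pooled connections must not be
    # shared across processes.
    configure_orm(disable_connection_pool=True)


executor = ProcessPoolExecutor(max_workers=1, initializer=_orm_initializer)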
11 changes: 9 additions & 2 deletions airflow/providers/openlineage/provider.yaml
@@ -45,8 +45,8 @@ dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-common-sql>=1.6.0
- attrs>=22.2
- openlineage-integration-common>=1.15.0
- openlineage-python>=1.15.0
- openlineage-integration-common>=1.16.0
- openlineage-python>=1.16.0

integrations:
- integration-name: OpenLineage
@@ -144,3 +144,10 @@ config:
example: ~
type: integer
version_added: 1.8.0
execution_timeout:
description: |
Maximum amount of time (in seconds) that OpenLineage can spend executing metadata extraction.
default: "10"
example: ~
type: integer
version_added: 1.9.0
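For deployments, the new option lives in the [openlineage] section, so it can be set in airflow.cfg or through Airflow's standard environment-variable override. A hedged example; the value 30 is arbitrary:

import os

# Equivalent to setting execution_timeout = 30 under [openlineage] in airflow.cfg.
os.environ["AIRFLOW__OPENLINEAGE__EXECUTION_TIMEOUT"] = "30"

from airflow.providers.openlineage.conf import execution_timeout

print(execution_timeout())  # -> 30 (cached after the first call outside pytest)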
4 changes: 2 additions & 2 deletions generated/provider_dependencies.json
@@ -913,8 +913,8 @@
"apache-airflow-providers-common-sql>=1.6.0",
"apache-airflow>=2.7.0",
"attrs>=22.2",
"openlineage-integration-common>=1.15.0",
"openlineage-python>=1.15.0"
"openlineage-integration-common>=1.16.0",
"openlineage-python>=1.16.0"
],
"devel-deps": [],
"plugins": [
60 changes: 60 additions & 0 deletions tests/dags/test_openlineage_execution.py
@@ -0,0 +1,60 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import time

from openlineage.client.generated.base import Dataset

from airflow.models.dag import DAG
from airflow.models.operator import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage


class OpenLineageExecutionOperator(BaseOperator):
def __init__(self, *, stall_amount=0, **kwargs) -> None:
super().__init__(**kwargs)
self.stall_amount = stall_amount

def execute(self, context):
self.log.error("STALL AMOUNT %s", self.stall_amount)
time.sleep(1)

def get_openlineage_facets_on_start(self):
return OperatorLineage(inputs=[Dataset(namespace="test", name="on-start")])

def get_openlineage_facets_on_complete(self, task_instance):
self.log.error("STALL AMOUNT %s", self.stall_amount)
time.sleep(self.stall_amount)
return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")])


with DAG(
dag_id="test_openlineage_execution",
default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
schedule="0 0 * * *",
dagrun_timeout=datetime.timedelta(minutes=60),
):
no_stall = OpenLineageExecutionOperator(task_id="execute_no_stall")

short_stall = OpenLineageExecutionOperator(task_id="execute_short_stall", stall_amount=5)

mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15)

long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30)
42 changes: 0 additions & 42 deletions tests/providers/openlineage/plugins/test_adapter.py
@@ -44,13 +44,7 @@
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.providers.openlineage.conf import (
config_path,
custom_extractors,
disabled_operators,
is_disabled,
is_source_enabled,
namespace,
transport,
)
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.plugins.adapter import _PRODUCER, OpenLineageAdapter
@@ -64,27 +58,6 @@
pytestmark = pytest.mark.db_test


@pytest.fixture(autouse=True)
def clear_cache():
config_path.cache_clear()
is_source_enabled.cache_clear()
disabled_operators.cache_clear()
custom_extractors.cache_clear()
namespace.cache_clear()
transport.cache_clear()
is_disabled.cache_clear()
try:
yield
finally:
config_path.cache_clear()
is_source_enabled.cache_clear()
disabled_operators.cache_clear()
custom_extractors.cache_clear()
namespace.cache_clear()
transport.cache_clear()
is_disabled.cache_clear()


@patch.dict(
os.environ,
{"OPENLINEAGE_URL": "http://ol-api:5000", "OPENLINEAGE_API_KEY": "api-key"},
@@ -155,9 +128,6 @@ def test_create_client_overrides_env_vars():
assert client.transport.kind == "http"
assert client.transport.url == "http://localhost:5050"

transport.cache_clear()
config_path.cache_clear()

with conf_vars({("openlineage", "transport"): '{"type": "console"}'}):
client = OpenLineageAdapter().get_or_create_openlineage_client()

@@ -893,9 +863,6 @@ def test_configuration_precedence_when_creating_ol_client():
assert client.transport.config.endpoint == "api/v1/lineage"
assert client.transport.config.auth.api_key == "random_token"

config_path.cache_clear()
transport.cache_clear()

# Second, check transport in Airflow configuration (airflow.cfg or env variable)
with patch.dict(
os.environ,
@@ -917,9 +884,6 @@
assert client.transport.kafka_config.topic == "test"
assert client.transport.kafka_config.config == {"acks": "all"}

config_path.cache_clear()
transport.cache_clear()

# Third, check legacy OPENLINEAGE_CONFIG env variable
with patch.dict(
os.environ,
@@ -942,9 +906,6 @@
assert client.transport.config.endpoint == "api/v1/lineage"
assert client.transport.config.auth.api_key == "random_token"

config_path.cache_clear()
transport.cache_clear()

# Fourth, check legacy OPENLINEAGE_URL env variable
with patch.dict(
os.environ,
@@ -967,9 +928,6 @@
assert client.transport.config.endpoint == "api/v1/lineage"
assert client.transport.config.auth.api_key == "test_api_key"

config_path.cache_clear()
transport.cache_clear()

# If all else fails, use console transport
with patch.dict(os.environ, {}, clear=True):
with conf_vars(