Skip to content

Commit

Permalink
openlineage: execute extraction and message sending in separate process
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
  • Loading branch information
mobuchowski committed Jun 11, 2024
1 parent 3d4661d commit 69865d6
Show file tree
Hide file tree
Showing 14 changed files with 407 additions and 149 deletions.
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/openlineage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,13 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None:
if not result:
return None
"""
# chain.pop modifies passed list, this can be unexpected
chain = chain.copy()
chain.reverse()
try:
while chain:
while isinstance(source, list) and len(source) == 1:
source = source[0]
next_key = chain.pop()
if isinstance(source, dict):
source = source.get(next_key)
Expand Down
17 changes: 16 additions & 1 deletion airflow/providers/openlineage/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@
import os
from typing import Any

from airflow.compat.functools import cache
# Disable caching if we're inside tests - this makes config easier to mock.
if os.getenv("PYTEST_VERSION"):

def decorator(func):
return func

cache = decorator
else:
from airflow.compat.functools import cache
from airflow.configuration import conf

_CONFIG_SECTION = "openlineage"
Expand Down Expand Up @@ -130,3 +138,10 @@ def dag_state_change_process_pool_size() -> int:
"""[openlineage] dag_state_change_process_pool_size."""
option = conf.get(_CONFIG_SECTION, "dag_state_change_process_pool_size", fallback="")
return _safe_int_convert(str(option).strip(), default=1)


@cache
def execution_timeout() -> int:
    """
    [openlineage] execution_timeout.

    Reads the ``execution_timeout`` option from the ``[openlineage]`` section (empty string
    when unset), strips its string form and passes it to ``_safe_int_convert`` with a
    default of ``10``.
    """
    raw_option = conf.get(_CONFIG_SECTION, "execution_timeout", fallback="")
    stripped_option = str(raw_option).strip()
    return _safe_int_convert(stripped_option, default=10)
46 changes: 43 additions & 3 deletions airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from __future__ import annotations

import logging
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING

import psutil
from openlineage.client.serde import Serde
from packaging.version import Version
from setproctitle import getproctitle, setproctitle

from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow.listeners import hookimpl
Expand All @@ -38,6 +41,7 @@
is_selective_lineage_enabled,
print_warning,
)
from airflow.settings import configure_orm
from airflow.stats import Stats
from airflow.utils.timeout import timeout

Expand Down Expand Up @@ -156,7 +160,7 @@ def on_running():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_running()
self._execute(on_running, "on_running", use_fork=True)

@hookimpl
def on_task_instance_success(
Expand Down Expand Up @@ -223,7 +227,7 @@ def on_success():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_success()
self._execute(on_success, "on_success", use_fork=True)

if _IS_AIRFLOW_2_10_OR_HIGHER:

Expand Down Expand Up @@ -318,10 +322,46 @@ def on_failure():
len(Serde.to_json(redacted_event).encode("utf-8")),
)

on_failure()
self._execute(on_failure, "on_failure", use_fork=True)

def _execute(self, callable, callable_name: str, use_fork: bool = False):
if use_fork:
self._fork_execute(callable, callable_name)
else:
callable()

def _fork_execute(self, callable, callable_name: str):
    """
    Run ``callable`` in a forked child process, bounded by ``conf.execution_timeout()``.

    Parent side: waits for the child up to the configured timeout, then makes sure the
    child is gone (kill + short wait so it does not linger as a zombie). Failures while
    waiting are swallowed on purpose - this code must not break the calling process.

    Child side: renames the process, reconfigures the ORM with connection pooling
    disabled, runs ``callable`` and *always* leaves through ``os._exit`` so that the
    child can never return into the code that invoked the listener.

    :param callable: zero-argument function executed in the child process
    :param callable_name: label appended to the child's process title and used in logs
    """
    self.log.debug("Will fork to execute OpenLineage process.")
    pid = os.fork()
    if pid:
        # Parent branch: os.fork() returned the child's pid.
        process = psutil.Process(pid)
        try:
            self.log.debug("Waiting for process %s", pid)
            process.wait(conf.execution_timeout())
        except psutil.TimeoutExpired:
            self.log.warning(
                "OpenLineage process %s expired. This should not affect process execution.", pid
            )
        except BaseException:
            # Best effort by design: whatever interrupted the wait, fall through and
            # kill the child below instead of propagating into the caller.
            self.log.debug(
                "Waiting for OpenLineage process %s was interrupted.", pid, exc_info=True
            )
        try:
            # Single kill for both the timeout and the interrupted-wait case. When the
            # child already exited this raises (e.g. psutil.NoSuchProcess) and the
            # reaping wait below is skipped.
            process.kill()
            # Reap the killed child so it does not stay around as a zombie.
            process.wait(1)
        except Exception:
            # Child already gone, or it could not be killed/reaped - nothing more to do.
            pass
        self.log.debug("Process with pid %s finished - parent", pid)
    else:
        # Child branch. Everything is wrapped so that no exception can escape the fork:
        # without this, a failure would let the child continue as a copy of the parent.
        try:
            setproctitle(getproctitle() + " - OpenLineage - " + callable_name)
            configure_orm(disable_connection_pool=True)
            self.log.debug("Executing OpenLineage process - %s - pid %s", callable_name, os.getpid())
            callable()
            self.log.debug("Process with current pid finishes after %s", callable_name)
        except Exception:
            self.log.warning("OpenLineage process for %s failed.", callable_name, exc_info=True)
        finally:
            # os._exit ends the child immediately without unwinding into the caller.
            os._exit(0)

@property
def executor(self) -> ProcessPoolExecutor:
# Executor for dag_run listener
def initializer():
# Re-configure the ORM engine as there are issues with multiple processes
# if process calls Airflow DB.
Expand Down
11 changes: 9 additions & 2 deletions airflow/providers/openlineage/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-common-sql>=1.6.0
- attrs>=22.2
- openlineage-integration-common>=1.15.0
- openlineage-python>=1.15.0
- openlineage-integration-common>=1.16.0
- openlineage-python>=1.16.0

integrations:
- integration-name: OpenLineage
Expand Down Expand Up @@ -144,3 +144,10 @@ config:
example: ~
type: integer
version_added: 1.8.0
execution_timeout:
description: |
Maximum amount of time (in seconds) that OpenLineage can spend executing metadata extraction.
default: "10"
example: ~
type: integer
version_added: 1.9.0
16 changes: 12 additions & 4 deletions airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_table_schemas,
)
from airflow.typing_compat import TypedDict
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from sqlalchemy.engine import Engine
Expand Down Expand Up @@ -116,19 +117,26 @@ def from_table_meta(
return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper())


class SQLParser:
class SQLParser(LoggingMixin):
"""Interface for openlineage-sql.
:param dialect: dialect specific to the database
:param default_schema: schema applied to each table with no schema parsed
"""

def __init__(self, dialect: str | None = None, default_schema: str | None = None) -> None:
    # Initialise the base class before storing the parser settings.
    super().__init__()
    # Both values are kept as given and passed to the parser on every ``parse`` call.
    self.default_schema = default_schema
    self.dialect = dialect

def parse(self, sql: list[str] | str) -> SqlMeta | None:
    """
    Parse a single or a list of SQL statements.

    Logs the statement(s), dialect and default schema at debug level, then delegates to
    the module-level ``parse`` function using this instance's dialect and default schema.
    """
    dialect, default_schema = self.dialect, self.default_schema
    self.log.debug(
        "OpenLineage calling SQL parser with SQL %s dialect %s schema %s",
        sql,
        dialect,
        default_schema,
    )
    parsed = parse(sql=sql, dialect=dialect, default_schema=default_schema)
    return parsed

def parse_table_schemas(
Expand All @@ -151,6 +159,7 @@ def parse_table_schemas(
"database": database or database_info.database,
"use_flat_cross_db_query": database_info.use_flat_cross_db_query,
}
self.log.info("PRE getting schemas for input and output tables")
return get_table_schemas(
hook,
namespace,
Expand Down Expand Up @@ -335,9 +344,8 @@ def split_statement(sql: str) -> list[str]:
return split_statement(sql)
return [obj for stmt in sql for obj in cls.split_sql_string(stmt) if obj != ""]

@classmethod
def create_information_schema_query(
cls,
self,
tables: list[DbTableMeta],
normalize_name: Callable[[str], str],
is_cross_db: bool,
Expand All @@ -349,7 +357,7 @@ def create_information_schema_query(
sqlalchemy_engine: Engine | None = None,
) -> str:
"""Create SELECT statement to query information schema table."""
tables_hierarchy = cls._get_tables_hierarchy(
tables_hierarchy = self._get_tables_hierarchy(
tables,
normalize_name=normalize_name,
database=database,
Expand Down
6 changes: 6 additions & 0 deletions airflow/providers/openlineage/utils/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
from collections import defaultdict
from contextlib import closing
from enum import IntEnum
Expand All @@ -33,6 +34,9 @@
from airflow.hooks.base import BaseHook


log = logging.getLogger(__name__)


class ColumnIndex(IntEnum):
"""Enumerates the indices of columns in information schema view."""

Expand Down Expand Up @@ -90,6 +94,7 @@ def get_table_schemas(
if not in_query and not out_query:
return [], []

log.debug("Starting to query database for table schemas")
with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
if in_query:
cursor.execute(in_query)
Expand All @@ -101,6 +106,7 @@ def get_table_schemas(
out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
else:
out_datasets = []
log.debug("Got table schema query result from database.")
return in_datasets, out_datasets


Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,10 @@ def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser

connection = self.get_connection(getattr(self, self.conn_name_attr))
namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))

if self.query_ids:
self.log.info("Getting connector to get database info :sadge:")
connection = self.get_connection(getattr(self, self.conn_name_attr))
namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
return OperatorLineage(
run_facets={
"externalQuery": ExternalQueryRunFacet(
Expand Down
4 changes: 2 additions & 2 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,8 @@
"apache-airflow-providers-common-sql>=1.6.0",
"apache-airflow>=2.7.0",
"attrs>=22.2",
"openlineage-integration-common>=1.15.0",
"openlineage-python>=1.15.0"
"openlineage-integration-common>=1.16.0",
"openlineage-python>=1.16.0"
],
"devel-deps": [],
"plugins": [
Expand Down
60 changes: 60 additions & 0 deletions tests/dags/test_openlineage_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import time

from openlineage.client.generated.base import Dataset

from airflow.models.dag import DAG
from airflow.models.operator import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage


class OpenLineageExecutionOperator(BaseOperator):
    """
    Operator whose OpenLineage "on complete" hook can be stalled for a configurable time.

    :param stall_amount: number of seconds ``get_openlineage_facets_on_complete`` sleeps
        before returning (defaults to 0)
    """

    def __init__(self, *, stall_amount=0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.stall_amount = stall_amount

    def execute(self, context):
        """Log the configured stall amount and sleep for one second."""
        self.log.error("STALL AMOUNT %s", self.stall_amount)
        time.sleep(1)

    def get_openlineage_facets_on_start(self):
        """Return lineage with a single input dataset ``test`` / ``on-start``."""
        start_input = Dataset(namespace="test", name="on-start")
        return OperatorLineage(inputs=[start_input])

    def get_openlineage_facets_on_complete(self, task_instance):
        """Log the stall amount, sleep for it, then return an ``on-complete`` input dataset."""
        stall_seconds = self.stall_amount
        self.log.error("STALL AMOUNT %s", stall_seconds)
        time.sleep(stall_seconds)
        complete_input = Dataset(namespace="test", name="on-complete")
        return OperatorLineage(inputs=[complete_input])


# Four independent tasks that differ only in how long their OpenLineage "on complete"
# hook sleeps (0, 5, 15 and 30 seconds).
with DAG(
    dag_id="test_openlineage_execution",
    default_args={
        "owner": "airflow",
        "retries": 3,
        "start_date": datetime.datetime(2022, 1, 1),
    },
    schedule="0 0 * * *",
    dagrun_timeout=datetime.timedelta(hours=1),
):
    # Uses the operator's default stall_amount of 0.
    no_stall = OpenLineageExecutionOperator(task_id="execute_no_stall")

    short_stall = OpenLineageExecutionOperator(task_id="execute_short_stall", stall_amount=5)

    mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15)

    long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30)
Loading

0 comments on commit 69865d6

Please sign in to comment.