feat: Add OpenLineage support for some SQL to GCS operators #45242

Merged · 1 commit · Jan 8, 2025
4 changes: 2 additions & 2 deletions generated/provider_dependencies.json
@@ -628,7 +628,7 @@
"google": {
"deps": [
"PyOpenSSL>=23.0.0",
"apache-airflow-providers-common-compat>=1.3.0",
"apache-airflow-providers-common-compat>=1.4.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.9.0",
"asgiref>=3.5.2",
@@ -970,7 +970,7 @@
},
"openlineage": {
"deps": [
"apache-airflow-providers-common-compat>=1.3.0",
"apache-airflow-providers-common-compat>=1.4.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.9.0",
"attrs>=22.2",
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/common/compat/__init__.py
@@ -29,7 +29,7 @@

__all__ = ["__version__"]

- __version__ = "1.3.0"
+ __version__ = "1.4.0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.9.0"
providers/src/airflow/providers/common/compat/openlineage/utils/sql.py (new file)
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql

else:
try:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
except ImportError:

def get_openlineage_facets_with_sql(
hook,
sql: str | list[str],
conn_id: str,
database: str | None,
):
try:
from airflow.providers.openlineage.sqlparser import SQLParser
except ImportError:
log.debug("SQLParser could not be imported from OpenLineage provider.")
return None

try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(conn_id)
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
log.debug("%s has no database info provided", hook)
database_info = None

if database_info is None:
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage


__all__ = ["get_openlineage_facets_with_sql"]
providers/src/airflow/providers/common/compat/provider.yaml
@@ -25,6 +25,7 @@ state: ready
source-date-epoch: 1731569875
# note that those versions are maintained by release manager - do not update them manually
versions:
- 1.4.0
- 1.3.0
- 1.2.2
- 1.2.1
providers/src/airflow/providers/google/cloud/transfers/mysql_to_gcs.py
@@ -22,6 +22,8 @@
import base64
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

try:
from MySQLdb.constants import FIELD_TYPE
@@ -37,6 +39,9 @@
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class MySQLToGCSOperator(BaseSQLToGCSOperator):
"""
@@ -77,10 +82,13 @@ def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs)
self.mysql_conn_id = mysql_conn_id
self.ensure_utc = ensure_utc

@cached_property
def db_hook(self) -> MySqlHook:
return MySqlHook(mysql_conn_id=self.mysql_conn_id)

def query(self):
"""Query mysql and returns a cursor to the results."""
- mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
- conn = mysql.get_conn()
+ conn = self.db_hook.get_conn()
cursor = conn.cursor()
if self.ensure_utc:
# Ensure TIMESTAMP results are in UTC
@@ -140,3 +148,20 @@ def convert_type(self, value, schema_type: str, **kwargs):
else:
value = base64.standard_b64encode(value).decode("ascii")
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.mysql_conn_id,
database=None,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
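
The Postgres and Trino operators below follow the same pattern: delegate SQL parsing to the hook via the compat helper, then override the parsed outputs with the GCS dataset. A rough, hypothetical illustration of what that means for a MySQL transfer follows; all values are invented, and real results depend on the live connection and on the OpenLineage SQL parser.

```python
# Hypothetical example - task id, bucket, filename and query are invented.
op = MySQLToGCSOperator(
    task_id="orders_to_gcs",
    sql="SELECT id, total FROM orders",
    bucket="my-bucket",
    filename="exports/orders_{}.json",
    mysql_conn_id="mysql_default",
)
lineage = op.get_openlineage_facets_on_start()
# inputs:     the `orders` table, if the OpenLineage provider is installed and
#             SQL parsing succeeds (a reachable MySQL connection is needed)
# outputs:    always the GCS dataset built by _get_openlineage_output_datasets()
# job_facets: the explicit SQLJobFacet is attached only on the fallback path;
#             otherwise facets come from the SQL parsing result itself
```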
providers/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py
@@ -24,13 +24,18 @@
import time
import uuid
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

import pendulum
from slugify import slugify

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class _PostgresServerSideCursorDecorator:
"""
@@ -132,10 +137,13 @@ def _unique_name(self):
)
return None

@cached_property
def db_hook(self) -> PostgresHook:
return PostgresHook(postgres_conn_id=self.postgres_conn_id)

def query(self):
"""Query Postgres and returns a cursor to the results."""
- hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
- conn = hook.get_conn()
+ conn = self.db_hook.get_conn()
cursor = conn.cursor(name=self._unique_name())
cursor.execute(self.sql, self.parameters)
if self.use_server_side_cursor:
@@ -180,3 +188,20 @@ def convert_type(self, value, schema_type, stringify_dict=True):
if isinstance(value, Decimal):
return float(value)
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.postgres_conn_id,
database=self.db_hook.database,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
providers/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -34,6 +34,7 @@
from airflow.providers.google.cloud.hooks.gcs import GCSHook

if TYPE_CHECKING:
from airflow.providers.common.compat.openlineage.facet import OutputDataset
from airflow.utils.context import Context


@@ -151,6 +152,7 @@ def __init__(
self.partition_columns = partition_columns
self.write_on_empty = write_on_empty
self.parquet_row_group_size = parquet_row_group_size
self._uploaded_file_names: list[str] = []

def execute(self, context: Context):
if self.partition_columns:
@@ -501,3 +503,16 @@ def _upload_to_gcs(self, file_to_upload):
gzip=self.gzip if is_data_file else False,
metadata=metadata,
)
self._uploaded_file_names.append(object_name)

def _get_openlineage_output_datasets(self) -> list[OutputDataset]:
"""Retrieve OpenLineage output datasets."""
from airflow.providers.common.compat.openlineage.facet import OutputDataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path

return [
OutputDataset(
namespace=f"gs://{self.bucket}",
name=extract_ds_name_from_gcs_path(self.filename.split("{}", maxsplit=1)[0]),
)
]
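
As a hedged illustration of how the output dataset name is derived: the bucket and filename below are invented, and the exact trimming is delegated to `extract_ds_name_from_gcs_path`.

```python
# Illustrative only - the filename template is made up.
# BaseSQLToGCSOperator filenames contain a "{}" placeholder for the file-chunk
# counter, so splitting on the first "{}" keeps just the fixed object prefix.
filename = "exports/orders/part_{}.json"
prefix = filename.split("{}", maxsplit=1)[0]   # -> "exports/orders/part_"
# extract_ds_name_from_gcs_path() is then expected to drop the dangling
# "part_" fragment, so the reported dataset is roughly:
#   OutputDataset(namespace="gs://my-bucket", name="exports/orders")
```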
providers/src/airflow/providers/google/cloud/transfers/trino_to_gcs.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
@@ -26,6 +27,8 @@
from trino.client import TrinoResult
from trino.dbapi import Cursor as TrinoCursor

from airflow.providers.openlineage.extractors import OperatorLineage


class _TrinoToGCSTrinoCursorAdapter:
"""
@@ -181,10 +184,13 @@ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs):
super().__init__(**kwargs)
self.trino_conn_id = trino_conn_id

@cached_property
def db_hook(self) -> TrinoHook:
return TrinoHook(trino_conn_id=self.trino_conn_id)

def query(self):
"""Query trino and returns a cursor to the results."""
- trino = TrinoHook(trino_conn_id=self.trino_conn_id)
- conn = trino.get_conn()
+ conn = self.db_hook.get_conn()
cursor = conn.cursor()
self.log.info("Executing: %s", self.sql)
cursor.execute(self.sql)
@@ -207,3 +213,20 @@ def convert_type(self, value, schema_type, **kwargs):
:param schema_type: BigQuery data type
"""
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.trino_conn_id,
database=None,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/provider.yaml
@@ -101,7 +101,7 @@ versions:

dependencies:
- apache-airflow>=2.9.0
- apache-airflow-providers-common-compat>=1.3.0
- apache-airflow-providers-common-compat>=1.4.0
- apache-airflow-providers-common-sql>=1.20.0
- asgiref>=3.5.2
- dill>=0.2.3
providers/src/airflow/providers/openlineage/provider.yaml
@@ -54,7 +54,7 @@ versions:
dependencies:
- apache-airflow>=2.9.0
- apache-airflow-providers-common-sql>=1.20.0
- apache-airflow-providers-common-compat>=1.3.0
- apache-airflow-providers-common-compat>=1.4.0
- attrs>=22.2
- openlineage-integration-common>=1.24.2
- openlineage-python>=1.24.2