Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@

import sqlparse
from attrs import define
from openlineage.client.facet import BaseFacet, ExtractionError, ExtractionErrorRunFacet, SqlJobFacet
from openlineage.client.facet import (
BaseFacet,
ColumnLineageDatasetFacet,
ColumnLineageDatasetFacetFieldsAdditional,
ColumnLineageDatasetFacetFieldsAdditionalInputFields,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.common.sql import DbTableMeta, SqlMeta, parse

from airflow.providers.openlineage.extractors.base import OperatorLineage
Expand Down Expand Up @@ -143,6 +151,47 @@ def parse_table_schemas(
else None,
)

def attach_column_lineage(
    self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
) -> None:
    """
    Attach a column lineage facet to every dataset in ``datasets``.

    Note that currently each dataset has the same column lineage information set.
    This would be a matter of change after OpenLineage SQL Parser improvements.

    :param datasets: output datasets to decorate; each gets a ``columnLineage``
        entry added to its ``facets`` dict (mutated in place).
    :param database: database name used as a fallback when the parser did not
        resolve the origin table's database.
    :param parse_result: parser output whose ``column_lineage`` entries are
        translated into the facet.
    """
    # No column-level lineage produced by the parser - nothing to attach.
    if not parse_result.column_lineage:
        return

    def _qualified_origin_name(column_meta) -> str:
        # Build "database.schema.table" for the source column's table,
        # falling back to the operator-level defaults for missing parts
        # and dropping parts that are still unknown.
        if not column_meta.origin:
            return ""
        return ".".join(
            part
            for part in (
                column_meta.origin.database or database,
                column_meta.origin.schema or self.default_schema,
                column_meta.origin.name,
            )
            if part
        )

    for dataset in datasets:
        dataset.facets["columnLineage"] = ColumnLineageDatasetFacet(
            fields={
                column_lineage.descendant.name: ColumnLineageDatasetFacetFieldsAdditional(
                    inputFields=[
                        ColumnLineageDatasetFacetFieldsAdditionalInputFields(
                            namespace=dataset.namespace,
                            name=_qualified_origin_name(column_meta),
                            field=column_meta.name,
                        )
                        for column_meta in column_lineage.lineage
                    ],
                    transformationType="",
                    transformationDescription="",
                )
                for column_lineage in parse_result.column_lineage
            }
        )

def generate_openlineage_metadata_from_sql(
self,
sql: list[str] | str,
Expand Down Expand Up @@ -198,6 +247,8 @@ def generate_openlineage_metadata_from_sql(
sqlalchemy_engine=sqlalchemy_engine,
)

self.attach_column_lineage(outputs, database or database_info.database, parse_result)

return OperatorLineage(
inputs=inputs,
outputs=outputs,
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/providers/openlineage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
16 changes: 16 additions & 0 deletions tests/integration/providers/openlineage/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
181 changes: 131 additions & 50 deletions tests/providers/openlineage/utils/test_sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from unittest.mock import MagicMock

import pytest
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
from openlineage.client.facet import (
ColumnLineageDatasetFacet,
ColumnLineageDatasetFacetFieldsAdditional,
ColumnLineageDatasetFacetFieldsAdditionalInputFields,
SchemaDatasetFacet,
SchemaField,
SqlJobFacet,
)
from openlineage.client.run import Dataset
from openlineage.common.sql import DbTableMeta

Expand All @@ -33,16 +40,6 @@

NAMESPACE = "test_namespace"

SCHEMA_FACET = SchemaDatasetFacet(
fields=[
SchemaField(name="ID", type="int4"),
SchemaField(name="AMOUNT_OFF", type="int4"),
SchemaField(name="CUSTOMER_EMAIL", type="varchar"),
SchemaField(name="STARTS_ON", type="timestamp"),
SchemaField(name="ENDS_ON", type="timestamp"),
]
)


def normalize_name_lower(name: str) -> str:
    """Normalize an identifier for case-insensitive comparison by lowercasing it."""
    return str.lower(name)
Expand All @@ -56,7 +53,8 @@ def test_get_tables_hierarchy(self):

# base check with db, no cross db
assert SQLParser._get_tables_hierarchy(
[DbTableMeta("Db.Schema1.Table1"), DbTableMeta("Db.Schema2.Table2")], normalize_name_lower
[DbTableMeta("Db.Schema1.Table1"), DbTableMeta("Db.Schema2.Table2")],
normalize_name_lower,
) == {None: {"schema1": ["Table1"], "schema2": ["Table2"]}}

# same, with cross db
Expand Down Expand Up @@ -148,20 +146,42 @@ def rows(name):
]

hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [
rows("TABLE_IN"),
rows("TABLE_OUT"),
rows("top_delivery_times"),
rows("popular_orders_day_of_week"),
]

expected_schema_facet = SchemaDatasetFacet(
fields=[
SchemaField(name="ID", type="int4"),
SchemaField(name="AMOUNT_OFF", type="int4"),
SchemaField(name="CUSTOMER_EMAIL", type="varchar"),
SchemaField(name="STARTS_ON", type="timestamp"),
SchemaField(name="ENDS_ON", type="timestamp"),
]
)

expected = (
[Dataset(namespace=NAMESPACE, name="PUBLIC.TABLE_IN", facets={"schema": SCHEMA_FACET})],
[Dataset(namespace=NAMESPACE, name="PUBLIC.TABLE_OUT", facets={"schema": SCHEMA_FACET})],
[
Dataset(
namespace=NAMESPACE,
name="PUBLIC.top_delivery_times",
facets={"schema": expected_schema_facet},
)
],
[
Dataset(
namespace=NAMESPACE,
name="PUBLIC.popular_orders_day_of_week",
facets={"schema": expected_schema_facet},
)
],
)

assert expected == parser.parse_table_schemas(
hook=hook,
namespace=NAMESPACE,
inputs=[DbTableMeta("TABLE_IN")],
outputs=[DbTableMeta("TABLE_OUT")],
inputs=[DbTableMeta("top_delivery_times")],
outputs=[DbTableMeta("popular_orders_day_of_week")],
database_info=db_info,
)

Expand All @@ -173,64 +193,125 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns

hook = MagicMock()

def rows(schema, table):
return [
(schema, table, "ID", 1, "int4"),
(schema, table, "AMOUNT_OFF", 2, "int4"),
(schema, table, "CUSTOMER_EMAIL", 3, "varchar"),
(schema, table, "STARTS_ON", 4, "timestamp"),
(schema, table, "ENDS_ON", 5, "timestamp"),
]
returned_schema = DB_SCHEMA_NAME if parser_returns_schema else None
returned_rows = [
[
(returned_schema, "top_delivery_times", "order_id", 1, "int4"),
(
returned_schema,
"top_delivery_times",
"order_placed_on",
2,
"timestamp",
),
(returned_schema, "top_delivery_times", "customer_email", 3, "varchar"),
],
[
(
returned_schema,
"popular_orders_day_of_week",
"order_day_of_week",
1,
"varchar",
),
(
returned_schema,
"popular_orders_day_of_week",
"order_placed_on",
2,
"timestamp",
),
(
returned_schema,
"popular_orders_day_of_week",
"orders_placed",
3,
"int4",
),
],
]

sql = """CREATE TABLE table_out (
ID int,
AMOUNT_OFF int,
CUSTOMER_EMAIL varchar,
STARTS_ON timestamp,
ENDS_ON timestamp
sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times
--irrelevant comment
)
;
"""

hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [
rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_IN"),
rows(DB_SCHEMA_NAME if parser_returns_schema else None, "TABLE_OUT"),
]
hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = returned_rows

mock_sql_meta = MagicMock()
if parser_returns_schema:
mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.TABLE_IN")]
mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.TABLE_OUT")]
mock_sql_meta.in_tables = [DbTableMeta("PUBLIC.top_delivery_times")]
mock_sql_meta.out_tables = [DbTableMeta("PUBLIC.popular_orders_day_of_week")]
else:
mock_sql_meta.in_tables = [DbTableMeta("TABLE_IN")]
mock_sql_meta.out_tables = [DbTableMeta("TABLE_OUT")]
mock_sql_meta.in_tables = [DbTableMeta("top_delivery_times")]
mock_sql_meta.out_tables = [DbTableMeta("popular_orders_day_of_week")]
mock_column_lineage = MagicMock()
mock_column_lineage.descendant.name = "order_day_of_week"
mock_lineage = MagicMock()
mock_lineage.name = "order_placed_on"
mock_lineage.origin.name = "top_delivery_times"
mock_lineage.origin.database = None
mock_lineage.origin.schema = "PUBLIC" if parser_returns_schema else None
mock_column_lineage.lineage = [mock_lineage]

mock_sql_meta.column_lineage = [mock_column_lineage]
mock_sql_meta.errors = []

mock_parse.return_value = mock_sql_meta

formatted_sql = """CREATE TABLE table_out (
ID int,
AMOUNT_OFF int,
CUSTOMER_EMAIL varchar,
STARTS_ON timestamp,
ENDS_ON timestamp
formatted_sql = """INSERT INTO popular_orders_day_of_week (order_day_of_week)
SELECT EXTRACT(ISODOW FROM order_placed_on) AS order_day_of_week
FROM top_delivery_times

)"""
expected_schema = "PUBLIC" if parser_returns_schema else "ANOTHER_SCHEMA"
expected = OperatorLineage(
inputs=[
Dataset(
namespace="myscheme://host:port",
name=f"{expected_schema}.TABLE_IN",
facets={"schema": SCHEMA_FACET},
name=f"{expected_schema}.top_delivery_times",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_id", type="int4"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="customer_email", type="varchar"),
]
)
},
)
],
outputs=[
Dataset(
namespace="myscheme://host:port",
name=f"{expected_schema}.TABLE_OUT",
facets={"schema": SCHEMA_FACET},
name=f"{expected_schema}.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
]
),
"columnLineage": ColumnLineageDatasetFacet(
fields={
"order_day_of_week": ColumnLineageDatasetFacetFieldsAdditional(
inputFields=[
ColumnLineageDatasetFacetFieldsAdditionalInputFields(
namespace="myscheme://host:port",
name=f"{expected_schema}.top_delivery_times",
field="order_placed_on",
)
],
transformationDescription="",
transformationType="",
)
}
),
},
)
],
job_facets={"sql": SqlJobFacet(query=formatted_sql)},
Expand Down