Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
from urllib.parse import urlparse

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand All @@ -36,6 +37,9 @@
from airflow.utils.strings import to_boolean

T = TypeVar("T")
if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo


def _try_to_boolean(value: Any):
Expand Down Expand Up @@ -448,3 +452,68 @@ def _get_cursor(self, conn: Any, return_dictionaries: bool):
finally:
if cursor is not None:
cursor.close()

def get_openlineage_database_info(self, connection) -> DatabaseInfo:
    """
    Build the ``DatabaseInfo`` describing this hook's Snowflake database.

    :param connection: connection object; its ``extra_dejson`` is consulted for
        ``database`` only when the hook has no ``database`` of its own. It is
        also forwarded to the dialect and authority helpers.
    :return: a ``DatabaseInfo`` flagged as cross-database for information
        schema lookups and as using upper-case names.
    """
    from airflow.providers.openlineage.sqlparser import DatabaseInfo

    # The hook-level database wins; the connection extra is only a fallback.
    db_name = self.database
    if not db_name:
        db_name = self._get_field(connection.extra_dejson, "database")

    dialect = self.get_openlineage_database_dialect(connection)
    authority = self._get_openlineage_authority(connection)
    schema_columns = [
        "table_schema",
        "table_name",
        "column_name",
        "ordinal_position",
        "data_type",
    ]

    return DatabaseInfo(
        scheme=dialect,
        authority=authority,
        information_schema_columns=schema_columns,
        database=db_name,
        is_information_schema_cross_db=True,
        is_uppercase_names=True,
    )

def get_openlineage_database_dialect(self, _) -> str:
    """Return the dialect name ``"snowflake"``; the positional argument is ignored."""
    dialect = "snowflake"
    return dialect

def get_openlineage_default_schema(self) -> str | None:
"""
Attempts to get current schema.

Usually ``SELECT CURRENT_SCHEMA();`` should work.
However, apparently you may set ``database`` without ``schema``
and get results from ``SELECT CURRENT_SCHEMAS();`` but not
from ``SELECT CURRENT_SCHEMA();``.
It still may return nothing if no database is set in connection.
"""
schema = self._get_conn_params()["schema"]
if not schema:
current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
if current_schemas:
_, schema = current_schemas.split(".")
return schema

def _get_openlineage_authority(self, _) -> str:
    """
    Return the hostname of this hook's URI after OpenLineage's Snowflake fix-up.

    The positional argument is ignored; the URI comes from ``self.get_uri()``.
    """
    from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri

    parsed_uri = urlparse(fix_snowflake_sqlalchemy_uri(self.get_uri()))
    return parsed_uri.hostname

def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None:
    """
    Report the first recorded Snowflake query id as an OpenLineage run facet.

    The positional argument is ignored.

    :return: an ``OperatorLineage`` whose ``run_facets`` holds an
        ``externalQuery`` facet built from ``self.query_ids[0]`` and this
        connection's namespace, or ``None`` when ``self.query_ids`` is empty.
    """
    # Nothing to report: skip the lazy imports and the connection lookup.
    if not self.query_ids:
        return None

    from openlineage.client.facet import ExternalQueryRunFacet

    from airflow.providers.openlineage.extractors import OperatorLineage
    from airflow.providers.openlineage.sqlparser import SQLParser

    connection = self.get_connection(getattr(self, self.conn_name_attr))
    # Use the hook's ``get_openlineage_database_info``; the previously called
    # ``get_database_info`` is not the name this hook defines.
    database_info = self.get_openlineage_database_info(connection)
    namespace = SQLParser.create_namespace(database_info)

    return OperatorLineage(
        run_facets={
            "externalQuery": ExternalQueryRunFacet(
                externalQueryId=self.query_ids[0], source=namespace
            )
        }
    )
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@
],
"cross-providers-deps": [
"common.sql",
"openlineage",
"slack"
],
"excluded-python-versions": []
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,33 @@ def test___ensure_prefixes(self):
"extra__snowflake__private_key_content",
"extra__snowflake__insecure_mode",
]

@pytest.mark.parametrize(
    "returned_schema,expected_schema",
    [([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
)
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_no_schema_set(
    self, mock_get_first, returned_schema, expected_schema
):
    """With no schema on the connection, the value comes from the patched ``get_first``."""
    mock_get_first.return_value = returned_schema
    schemaless_connection = Connection(**{**BASE_CONNECTION_KWARGS, "schema": None})

    # The connection is exposed to the hook through the environment.
    with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=schemaless_connection.get_uri()):
        hook = SnowflakeHook(snowflake_conn_id="test_conn")
        assert hook.get_openlineage_default_schema() == expected_schema

@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
    """A schema from the connection or the hook argument is returned without calling ``get_first``."""
    connection_uri = Connection(**BASE_CONNECTION_KWARGS).get_uri()

    with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri):
        # Schema taken from the connection itself.
        connection_schema_hook = SnowflakeHook(snowflake_conn_id="test_conn")
        assert connection_schema_hook.get_openlineage_default_schema() == BASE_CONNECTION_KWARGS["schema"]
        mock_get_first.assert_not_called()

        # Schema passed directly to the hook.
        hook_with_schema_param = SnowflakeHook(snowflake_conn_id="test_conn", schema="my_schema")
        assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
        mock_get_first.assert_not_called()
68 changes: 68 additions & 0 deletions tests/providers/snowflake/operators/test_snowflake_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

import pytest
from databricks.sql.types import Row
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
from openlineage.client.run import Dataset

from airflow.models.connection import Connection
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator

DATE = "2017-04-20"
Expand Down Expand Up @@ -138,3 +142,67 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
return_last=return_last,
split_statements=split_statement,
)


def test_execute_openlineage_events():
    """
    Check the lineage a ``SnowflakeOperator`` reports on start for a CREATE TABLE script.

    The hook's connection, cursor and ``get_first`` are stubbed. The script ends
    with a ``FORGOT TO COMMENT`` statement, and exactly one failed task is
    expected in the ``extractionError`` run facet.
    """
    database = "DATABASE"
    schema = "PUBLIC"
    table = "POPULAR_ORDERS_DAY_OF_WEEK"
    # (column name, reported type) in ordinal order.
    columns = [
        ("ORDER_DAY_OF_WEEK", "TEXT"),
        ("ORDER_PLACED_ON", "TIMESTAMP_NTZ"),
        ("ORDERS_PLACED", "NUMBER"),
    ]

    class StubbedSnowflakeHook(SnowflakeHook):
        get_conn = MagicMock(name="conn")
        get_connection = MagicMock()

        def get_first(self, *_):
            return [f"{database}.{schema}"]

    dbapi_hook = StubbedSnowflakeHook()

    class StubbedSnowflakeOperator(SnowflakeOperator):
        def get_db_hook(self):
            return dbapi_hook

    sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
order_day_of_week VARCHAR(64) NOT NULL,
order_placed_on TIMESTAMP NOT NULL,
orders_placed INTEGER NOT NULL
);
FORGOT TO COMMENT"""
    op = StubbedSnowflakeOperator(task_id="snowflake-operator", sql=sql)

    # Rows shaped like the five information-schema columns the hook asks for.
    rows = [
        (schema, table, column_name, position, data_type)
        for position, (column_name, data_type) in enumerate(columns, start=1)
    ]
    dbapi_hook.get_connection.return_value = Connection(
        conn_id="snowflake_default",
        conn_type="snowflake",
        extra={
            "account": "test_account",
            "region": "us-east",
            "warehouse": "snow-warehouse",
            "database": database,
        },
    )
    # First fetchall call yields the column rows, the second yields nothing.
    cursor = dbapi_hook.get_conn.return_value.cursor.return_value
    cursor.fetchall.side_effect = [rows, []]

    lineage = op.get_openlineage_facets_on_start()

    expected_output = Dataset(
        namespace="snowflake://test_account.us-east.aws",
        name=f"{database}.{schema}.{table}",
        facets={
            "schema": SchemaDatasetFacet(
                fields=[
                    SchemaField(name=column_name, type=data_type)
                    for column_name, data_type in columns
                ]
            )
        },
    )
    assert len(lineage.inputs) == 0
    assert lineage.outputs == [expected_output]

    assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}

    assert lineage.run_facets["extractionError"].failedTasks == 1