163 changes: 160 additions & 3 deletions airflow/providers/snowflake/transfers/copy_into_snowflake.py
@@ -108,8 +108,12 @@ def __init__(
self.copy_options = copy_options
self.validation_mode = validation_mode

self.hook: SnowflakeHook | None = None
self._sql: str | None = None
self._result: list[dict[str, Any]] = []

def execute(self, context: Any) -> None:
snowflake_hook = SnowflakeHook(
self.hook = SnowflakeHook(
snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse,
database=self.database,
@@ -127,7 +131,7 @@ def execute(self, context: Any) -> None:
if self.columns_array:
into = f"{into}({', '.join(self.columns_array)})"

sql = f"""
self._sql = f"""
COPY INTO {into}
FROM @{self.stage}/{self.prefix or ""}
{"FILES=(" + ",".join(map(enclose_param, self.files)) + ")" if self.files else ""}
@@ -137,5 +141,158 @@ def execute(self, context: Any) -> None:
{self.validation_mode or ""}
"""
self.log.info("Executing COPY command...")
snowflake_hook.run(sql=sql, autocommit=self.autocommit)
self._result = self.hook.run( # type: ignore # mypy does not work well with return_dictionaries=True
sql=self._sql,
autocommit=self.autocommit,
handler=lambda x: x.fetchall(),
return_dictionaries=True,
)
self.log.info("COPY command completed")

@staticmethod
def _extract_openlineage_unique_dataset_paths(
query_result: list[dict[str, Any]],
) -> tuple[list[tuple[str, str]], list[str]]:
"""Extracts and returns unique OpenLineage dataset paths and file paths that failed to be parsed.

Each row in the results is expected to have a 'file' field, which is a URI.
The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples.
Additionally, it captures any URIs that cannot be parsed or processed
and returns them in a separate error list.

For Azure, Snowflake has its own way of representing the URI:
azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
which this function transforms into a Dataset with a more universal naming convention:
Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described in
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage

:param query_result: A list of dictionaries, each containing a 'file' key with a URI value.
:return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path,
and the second contains any URIs that cannot be parsed or processed correctly.

>>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths

>>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
>>> method(results)
([('wasbs://azure_container@my_account', 'dir3')], [])

>>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container"}]
>>> method(results)
([('wasbs://azure_container@my_account', '/')], [])

>>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, {"file": "s3://bucket/a.csv"}]
>>> method(results)
([('gcs://bucket', '/'), ('s3://bucket', '/')], [])

>>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": "gcs://bucket/dir/dir2/a.txt"}]
>>> method(results)
([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])

>>> results = [
... {"file": "s3://bucket/dir/file.csv"},
... {"file": "azure://my_account.something_new.windows.net/azure_container"},
... ]
>>> method(results)
([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container'])
"""
import re
from pathlib import Path
from urllib.parse import urlparse

azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
extraction_error_files = []
unique_dataset_paths = set()

for row in query_result:
uri = urlparse(row["file"])
if uri.scheme == "azure":
match = re.fullmatch(azure_regex, row["file"])
if not match:
extraction_error_files.append(row["file"])
continue
account_name, container_name, name = match.groups()
namespace = f"wasbs://{container_name}@{account_name}"
else:
namespace = f"{uri.scheme}://{uri.netloc}"
name = uri.path.lstrip("/")

name = Path(name).parent.as_posix()
if name in ("", "."):
name = "/"

unique_dataset_paths.add((namespace, name))

return sorted(unique_dataset_paths), sorted(extraction_error_files)

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement _on_complete because we rely on return value of a query."""
import re

from openlineage.client.facet import (
ExternalQueryRunFacet,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser

if not self._sql:
return OperatorLineage()

query_results = self._result or []
# If no files were uploaded we get [{"status": "0 files were uploaded..."}]
if len(query_results) == 1 and query_results[0].get("status"):
query_results = []
unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths(
query_results
)
input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths]

run_facets = {}
if extraction_error_files:
self.log.debug(
f"Unable to extract Dataset namespace and name "
f"for the following files: `{extraction_error_files}`."
)
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=len(query_results),
failedTasks=len(extraction_error_files),
errors=[
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task=file_uri,
taskNumber=None,
)
for file_uri in extraction_error_files
],
)

connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr)))
database_info = self.hook.get_openlineage_database_info(connection)

dest_name = self.table
schema = self.hook.get_openlineage_default_schema()
database = database_info.database
if schema:
dest_name = f"{schema}.{dest_name}"
if database:
dest_name = f"{database}.{dest_name}"

snowflake_namespace = SQLParser.create_namespace(database_info)
query = SQLParser.normalize_sql(self._sql)
query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))

run_facets["externalQuery"] = ExternalQueryRunFacet(
externalQueryId=self.hook.query_ids[0], source=snowflake_namespace
)

return OperatorLineage(
inputs=input_datasets,
outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)],
job_facets={"sql": SqlJobFacet(query=query)},
run_facets=run_facets,
)
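
For illustration, a minimal sketch of how the new static helper maps COPY result rows to OpenLineage dataset paths. The sample rows are made up (they mirror the doctest examples above), and the helper is called directly here only for demonstration:

from airflow.providers.snowflake.transfers.copy_into_snowflake import (
    CopyFromExternalStageToSnowflakeOperator,
)

# Illustrative rows shaped like the dictionaries the COPY INTO query returns.
rows = [
    {"file": "s3://bucket/dir/file.csv"},
    {"file": "azure://my_account.blob.core.windows.net/container/dir/file.csv"},
]

paths, errors = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths(rows)
print(paths)   # [('s3://bucket', 'dir'), ('wasbs://container@my_account', 'dir')]
print(errors)  # []
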
168 changes: 167 additions & 1 deletion tests/providers/snowflake/transfers/test_copy_into_snowflake.py
@@ -16,8 +16,20 @@
# under the License.
from __future__ import annotations

from typing import Callable
from unittest import mock

from openlineage.client.facet import (
ExternalQueryRunFacet,
ExtractionError,
ExtractionErrorRunFacet,
SqlJobFacet,
)
from openlineage.client.run import Dataset
from pytest import mark

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo
from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator


@@ -62,4 +74,158 @@ def test_execute(self, mock_hook):
validation_mode
"""

mock_hook.return_value.run.assert_called_once_with(sql=sql, autocommit=True)
mock_hook.return_value.run.assert_called_once_with(
sql=sql, autocommit=True, return_dictionaries=True, handler=mock.ANY
)

handler = mock_hook.return_value.run.mock_calls[0].kwargs.get("handler")
assert isinstance(handler, Callable)

@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete(self, mock_hook):
mock_hook().run.return_value = [
{"file": "s3://aws_bucket_name/dir1/file.csv"},
{"file": "s3://aws_bucket_name_2"},
{"file": "gcs://gcs_bucket_name/dir2/file.csv"},
{"file": "gcs://gcs_bucket_name_2"},
{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
{"file": "azure://my_account.blob.core.windows.net/azure_container_2"},
]
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_inputs = [
Dataset(namespace="gcs://gcs_bucket_name", name="dir2"),
Dataset(namespace="gcs://gcs_bucket_name_2", name="/"),
Dataset(namespace="s3://aws_bucket_name", name="dir1"),
Dataset(namespace="s3://aws_bucket_name_2", name="/"),
Dataset(namespace="wasbs://azure_container@my_account", name="dir3"),
Dataset(namespace="wasbs://azure_container_2@my_account", name="/"),
]
expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=expected_inputs,
outputs=expected_outputs,
run_facets={
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
)
},
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)

@mark.parametrize("rows", (None, []))
@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, rows):
mock_hook().run.return_value = rows
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=[],
outputs=expected_outputs,
run_facets={
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
)
},
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)

@mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hook):
mock_hook().run.return_value = [
{"file": "s3://aws_bucket_name/dir1/file.csv"},
{"file": "gs://gcp_bucket_name/dir2/file.csv"},
{"file": "azure://my_account.weird-url.net/azure_container/dir3/file.csv"},
{"file": "azure://my_account.another_weird-url.net/con/file.csv"},
]
mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
scheme="snowflake_scheme", authority="authority", database="actual_database"
)
mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
mock_hook().query_ids = ["query_id_123"]

expected_inputs = [
Dataset(namespace="gs://gcp_bucket_name", name="dir2"),
Dataset(namespace="s3://aws_bucket_name", name="dir1"),
]
expected_outputs = [
Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
]
expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
expected_run_facets = {
"extractionError": ExtractionErrorRunFacet(
totalTasks=4,
failedTasks=2,
errors=[
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task="azure://my_account.another_weird-url.net/con/file.csv",
taskNumber=None,
),
ExtractionError(
errorMessage="Unable to extract Dataset namespace and name.",
stackTrace=None,
task="azure://my_account.weird-url.net/azure_container/dir3/file.csv",
taskNumber=None,
),
],
),
"externalQuery": ExternalQueryRunFacet(
externalQueryId="query_id_123", source="snowflake_scheme://authority"
),
}

op = CopyFromExternalStageToSnowflakeOperator(
task_id="test",
table="table",
stage="stage",
database="",
schema="schema",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage(
inputs=expected_inputs,
outputs=expected_outputs,
run_facets=expected_run_facets,
job_facets={"sql": SqlJobFacet(query=expected_sql)},
)
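
As a usage note, a minimal DAG sketch showing how the operator might be wired up so that get_openlineage_facets_on_complete can report the stage files as inputs and the Snowflake table as the output; the DAG id, connection id, stage, schema, and table names are hypothetical:

import pendulum

from airflow import DAG
from airflow.providers.snowflake.transfers.copy_into_snowflake import (
    CopyFromExternalStageToSnowflakeOperator,
)

with DAG(
    dag_id="copy_into_snowflake_example",  # hypothetical DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    CopyFromExternalStageToSnowflakeOperator(
        task_id="copy_from_external_stage",
        snowflake_conn_id="snowflake_default",  # assumed connection id
        table="my_table",
        schema="my_schema",
        stage="my_external_stage",
        file_format="CSV",
    )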