4 changes: 3 additions & 1 deletion generated/provider_dependencies.json
@@ -449,7 +449,9 @@
}
],
"cross-providers-deps": [
"common.sql"
"common.compat",
"common.sql",
"openlineage"
],
"excluded-python-versions": [],
"state": "ready"
171 changes: 168 additions & 3 deletions providers/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -273,7 +273,12 @@ def __init__(
if force_copy is not None:
self._copy_options["force"] = "true" if force_copy else "false"

# These will be used by OpenLineage
self._sql: str | None = None
self._result: list[Any] = []
Contributor: I think this is no longer needed?


def _get_hook(self) -> DatabricksSqlHook:
"""Get a DatabricksSqlHook properly configured for this operator."""
return DatabricksSqlHook(
self.databricks_conn_id,
http_path=self._http_path,
@@ -293,6 +298,11 @@ def _generate_options(
opts: dict[str, str] | None = None,
escape_key: bool = True,
) -> str:
"""
Generate the bracketed options clause for the COPY INTO command.

Example: FORMAT_OPTIONS (header = 'true', inferSchema = 'true').
"""
formatted_opts = ""
if opts:
pairs = [
@@ -304,6 +314,7 @@
return formatted_opts

def _create_sql_query(self) -> str:
"""Create the COPY INTO statement from the provided options."""
escaper = ParamEscaper()
maybe_with = ""
if self._encryption is not None or self._credential is not None:
@@ -349,12 +360,166 @@ def _create_sql_query(self) -> str:
return sql.strip()

def execute(self, context: Context) -> Any:
sql = self._create_sql_query()
self.log.info("Executing: %s", sql)
"""Execute the COPY INTO command and store the result for lineage reporting."""
self._sql = self._create_sql_query()
self.log.info("Executing SQL: %s", self._sql)

hook = self._get_hook()
hook.run(sql)
hook.run(self._sql)

def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling gets
# handled in `DatabricksSqlHook.run()` method which is called in `execute()`
...

def _parse_input_dataset(self) -> tuple[list[Any], list[Any]]:
"""Parse file_location to build the input dataset."""
from airflow.providers.common.compat.openlineage.facet import Dataset, Error

input_datasets: list[Dataset] = []
extraction_errors: list[Error] = []

if not self.file_location:
return input_datasets, extraction_errors

try:
from urllib.parse import urlparse

parsed_uri = urlparse(self.file_location)
# Only process known schemes
if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs", "azure", "abfss", "wasbs"):
raise ValueError(f"Unsupported scheme: {parsed_uri.scheme}")

scheme = parsed_uri.scheme
namespace = f"{scheme}://{parsed_uri.netloc}"
path = parsed_uri.path.lstrip("/") or "/"
input_datasets.append(Dataset(namespace=namespace, name=path))
except Exception as e:
self.log.error("Failed to parse file_location: %s, error: %s", self.file_location, str(e))
extraction_errors.append(
Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
)

return input_datasets, extraction_errors

def _create_sql_job_facet(self) -> tuple[dict, list[Any]]:
"""Create SQL job facet from the SQL query."""
from airflow.providers.common.compat.openlineage.facet import Error, SQLJobFacet
from airflow.providers.openlineage.sqlparser import SQLParser

job_facets = {}
extraction_errors: list[Error] = []

try:
import re

normalized_sql = SQLParser.normalize_sql(self._sql)
normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ", normalized_sql))
Contributor: I think we usually only use SQLParser.normalize_sql for the SQLJobFacet. What is the reason for these additional replacements? Could you add some comments if they're necessary?

Contributor (author): This is done in the CopyFromExternalStageToSnowflakeOperator OL implementation here:

query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))
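For context, a minimal standalone illustration (not part of the PR) of what this extra substitution does on top of normalize_sql: it collapses repeated spaces and blank lines in the generated statement.

import re

# Hypothetical multi-line statement, similar in shape to what _create_sql_query() produces.
query = """COPY INTO schema.table
FROM  's3://bucket/dir1'


FILEFORMAT = CSV"""

# Same substitution as above: collapse runs of spaces, then runs of newlines.
normalized = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))

print(normalized)
# COPY INTO schema.table
# FROM 's3://bucket/dir1'
# FILEFORMAT = CSV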

job_facets["sql"] = SQLJobFacet(query=normalized_sql)
except Exception as e:
self.log.error("Failed creating SQL job facet: %s", str(e))
Contributor: I think we usually try not to log at error level unless absolutely necessary. Could you review the code and adjust it in other places as well? Maybe warning is enough? WDYT?

extraction_errors.append(
Error(errorMessage=str(e), stackTrace=None, task="sql_facet_creation", taskNumber=None)
)

return job_facets, extraction_errors

def _build_output_dataset(self) -> tuple[Any, list[Any]]:
"""Build output dataset from table information."""
from airflow.providers.common.compat.openlineage.facet import Dataset, Error

output_dataset = None
extraction_errors: list[Error] = []

if not self.table_name:
return output_dataset, extraction_errors

try:
table_parts = self.table_name.split(".")
if len(table_parts) == 3: # catalog.schema.table
catalog, schema, table = table_parts
elif len(table_parts) == 2: # schema.table
catalog = None
schema, table = table_parts
else:
catalog = None
schema = None
table = self.table_name

hook = self._get_hook()
conn = hook.get_connection(hook.databricks_conn_id)
output_namespace = f"databricks://{conn.host}"

# Combine schema/table with optional catalog for final dataset name
fq_name = table
if schema:
fq_name = f"{schema}.{fq_name}"
if catalog:
fq_name = f"{catalog}.{fq_name}"

output_dataset = Dataset(namespace=output_namespace, name=fq_name)
except Exception as e:
self.log.error("Failed to construct output dataset: %s", str(e))
extraction_errors.append(
Error(
errorMessage=str(e),
stackTrace=None,
task="output_dataset_construction",
taskNumber=None,
)
)

return output_dataset, extraction_errors

def get_openlineage_facets_on_complete(self, task_instance):
"""
Compute OpenLineage facets for the COPY INTO command.

Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and build an
input dataset list and an output dataset (the Delta table).
"""
from airflow.providers.common.compat.openlineage.facet import ExtractionErrorRunFacet
from airflow.providers.openlineage.extractors import OperatorLineage

if not self._sql:
self.log.warning("No SQL query found, returning empty OperatorLineage.")
return OperatorLineage()

# Get input datasets and any parsing errors
input_datasets, extraction_errors = self._parse_input_dataset()

# Create SQL job facet
job_facets, sql_errors = self._create_sql_job_facet()
extraction_errors.extend(sql_errors)

run_facets = {}
if extraction_errors:
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=1,
failedTasks=len(extraction_errors),
errors=extraction_errors,
)
# Return only error facets for invalid URIs
return OperatorLineage(
inputs=[],
outputs=[],
job_facets=job_facets,
run_facets=run_facets,
)
Comment on lines +496 to +508
Contributor: Shouldn't we try to return the output dataset even if the inputs are incorrect?


# Build output dataset
output_dataset, output_errors = self._build_output_dataset()
if output_errors:
extraction_errors.extend(output_errors)
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=1,
failedTasks=len(extraction_errors),
errors=extraction_errors,
)

return OperatorLineage(
inputs=input_datasets,
outputs=[output_dataset] if output_dataset else [],
job_facets=job_facets,
run_facets=run_facets,
)
160 changes: 156 additions & 4 deletions providers/tests/databricks/operators/test_databricks_copy.py
@@ -17,10 +17,17 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
SQLJobFacet,
)
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
from airflow.providers.openlineage.extractors import OperatorLineage

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
@@ -140,10 +147,8 @@ def test_copy_with_encryption_and_credential():
assert (
op._create_sql_query()
== f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """
"""ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV
""".strip()
FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV""".strip()
)


@@ -253,3 +258,150 @@ def test_templating(create_task_instance_of_operator, session):
assert task.files == "files"
assert task.table_name == "table-name"
assert task.databricks_conn_id == "databricks-conn-id"


@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_s3(mock_hook):
"""Test OpenLineage facets generation for S3 source."""
mock_hook().run.return_value = [
{"file": "s3://bucket/dir1/file1.csv"},
{"file": "s3://bucket/dir1/file2.csv"},
]
mock_hook().get_connection().host = "databricks.com"

op = DatabricksCopyIntoOperator(
task_id="test",
table_name="schema.table",
file_location="s3://bucket/dir1",
file_format="CSV",
)
op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'"
op._result = mock_hook().run.return_value
Contributor: I think this is already gone from the operator, so the tests should be adjusted?


lineage = op.get_openlineage_facets_on_complete(None)

assert lineage == OperatorLineage(
inputs=[Dataset(namespace="s3://bucket", name="dir1")],
outputs=[Dataset(namespace="databricks://databricks.com", name="schema.table")],
job_facets={"sql": SQLJobFacet(query="COPY INTO schema.table FROM 's3://bucket/dir1'")},
run_facets={},
)


@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_with_errors(mock_hook):
"""Test OpenLineage facets generation with extraction errors."""
mock_hook().run.return_value = [
{"file": "s3://bucket/dir1/file1.csv"},
{"file": "invalid://location/file.csv"}, # Invalid URI
{"file": "azure://account.invalid.windows.net/container/file.csv"}, # Invalid Azure URI
]
mock_hook().get_connection().host = "databricks.com"
Comment on lines +292 to +299
Contributor: So we are passing invalid URIs and then checking that there are no extraction errors? Is this test valid?


op = DatabricksCopyIntoOperator(
task_id="test",
table_name="schema.table",
file_location="s3://bucket/dir1",
file_format="CSV",
)
op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'"
op._result = mock_hook().run.return_value

lineage = op.get_openlineage_facets_on_complete(None)

# Check inputs and outputs
assert len(lineage.inputs) == 1
assert lineage.inputs[0].namespace == "s3://bucket"
assert lineage.inputs[0].name == "dir1"

assert len(lineage.outputs) == 1
assert lineage.outputs[0].namespace == "databricks://databricks.com"
assert lineage.outputs[0].name == "schema.table"

# Check facets exist and have correct structure
assert "sql" in lineage.job_facets
assert lineage.job_facets["sql"].query == "COPY INTO schema.table FROM 's3://bucket/dir1'"

assert "extractionError" not in lineage.run_facets


@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_no_sql(mock_hook):
Contributor: Maybe we should run execute and then explicitly overwrite self._sql? Or at least manually make sure _sql is None. This test assumes self._sql is initialized as None, but we don't check it.

"""Test OpenLineage facets generation when no SQL is available."""
op = DatabricksCopyIntoOperator(
task_id="test",
table_name="schema.table",
file_location="s3://bucket/dir1",
file_format="CSV",
)

lineage = op.get_openlineage_facets_on_complete(None)
assert lineage == OperatorLineage()
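A minimal sketch of what the reviewer suggests (hypothetical test name, reusing the imports already present in this module): make the "_sql is None" precondition explicit instead of relying on the __init__ default.

@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_no_sql_explicit(mock_hook):
    """Sketch: assert the precondition rather than assuming the __init__ default."""
    op = DatabricksCopyIntoOperator(
        task_id="test",
        table_name="schema.table",
        file_location="s3://bucket/dir1",
        file_format="CSV",
    )
    # Make the "no SQL executed yet" precondition explicit.
    op._sql = None
    assert op._sql is None

    assert op.get_openlineage_facets_on_complete(None) == OperatorLineage()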


@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_gcs(mock_hook):
"""Test OpenLineage facets generation specifically for GCS paths."""
mock_hook().run.return_value = [
{"file": "gs://bucket1/dir1/file1.csv"},
{"file": "gs://bucket1/dir2/nested/file2.csv"},
{"file": "gs://bucket2/file3.csv"},
{"file": "gs://bucket2"}, # Edge case: root path
{"file": "gs://invalid-bucket/@#$%"}, # Invalid path
]
mock_hook().get_connection.return_value.host = "databricks.com"
mock_hook().query_ids = ["query_123"]

op = DatabricksCopyIntoOperator(
task_id="test",
table_name="catalog.schema.table",
file_location="gs://location",
file_format="CSV",
)
op.execute(None)
result = op.get_openlineage_facets_on_complete(None)

# Check inputs - only one input from file_location
assert len(result.inputs) == 1
assert result.inputs[0].namespace == "gs://location"
assert result.inputs[0].name == "/"

# Check outputs
assert len(result.outputs) == 1
assert result.outputs[0].namespace == "databricks://databricks.com"
assert result.outputs[0].name == "catalog.schema.table"

# Check SQL job facet
assert "sql" in result.job_facets
assert "COPY INTO catalog.schema.table" in result.job_facets["sql"].query
assert "FILEFORMAT = CSV" in result.job_facets["sql"].query
Comment on lines +376 to +377
Contributor: Maybe we should check the whole query, or at least also check whether the GCS path is there? WDYT?
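A sketch of the stricter check being suggested (hypothetical extra assertion for the test above, reusing its result object):

# Hypothetical additional assertion for test_get_openlineage_facets_on_complete_gcs:
# verify the source location also made it into the SQL job facet.
assert "gs://location" in result.job_facets["sql"].query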



@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_invalid_location(mock_hook):
"""Test OpenLineage facets generation with invalid file_location."""
mock_hook().get_connection().host = "databricks.com"

op = DatabricksCopyIntoOperator(
task_id="test",
table_name="schema.table",
file_location="invalid://location", # Invalid location
file_format="CSV",
)
op._sql = "COPY INTO schema.table FROM 'invalid://location'"
op._result = [{"file": "s3://bucket/file.csv"}]
Comment on lines +391 to +392
Contributor: Why not actually execute the operator instead?


lineage = op.get_openlineage_facets_on_complete(None)

# Should have no inputs due to invalid location
assert len(lineage.inputs) == 0

# Should not have output and SQL facets
assert len(lineage.outputs) == 0
assert "sql" in lineage.job_facets

# Should have extraction error facet
assert "extractionError" in lineage.run_facets
assert lineage.run_facets["extractionError"].totalTasks == 1
assert lineage.run_facets["extractionError"].failedTasks == 1
assert len(lineage.run_facets["extractionError"].errors) == 1
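A sketch of the reviewer's suggestion (hypothetical test; DatabricksSqlHook is patched as in the other tests, so execute() runs no real query): let execute() populate self._sql instead of assigning it by hand.

@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_invalid_location_via_execute(mock_hook):
    """Sketch: drive the operator state through execute() with the mocked hook."""
    mock_hook().get_connection().host = "databricks.com"

    op = DatabricksCopyIntoOperator(
        task_id="test",
        table_name="schema.table",
        file_location="invalid://location",  # invalid scheme, as in the test above
        file_format="CSV",
    )
    op.execute(None)  # hook.run() is mocked, so only self._sql gets populated

    lineage = op.get_openlineage_facets_on_complete(None)
    assert lineage.inputs == []
    assert "extractionError" in lineage.run_facets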