# [OpenLineage] Add OpenLineage support for DatabricksCopyIntoOperator #45257
@@ -273,7 +273,12 @@ def __init__(
        if force_copy is not None:
            self._copy_options["force"] = "true" if force_copy else "false"

        # These will be used by OpenLineage
        self._sql: str | None = None
        self._result: list[Any] = []

    def _get_hook(self) -> DatabricksSqlHook:
        """Get a DatabricksSqlHook properly configured for this operator."""
        return DatabricksSqlHook(
            self.databricks_conn_id,
            http_path=self._http_path,
@@ -293,6 +298,11 @@ def _generate_options(
        opts: dict[str, str] | None = None,
        escape_key: bool = True,
    ) -> str:
        """
        Generate the bracketed options clause for the COPY INTO command.

        Example: FORMAT_OPTIONS (header = 'true', inferSchema = 'true').
        """
        formatted_opts = ""
        if opts:
            pairs = [
@@ -304,6 +314,7 @@ def _generate_options(
        return formatted_opts

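For context, the clause construction documented above can be sketched standalone. `render_options` below is a hypothetical simplification (the real method also routes keys and values through `ParamEscaper` and honors `escape_key`):

```python
# Minimal sketch of the options rendering, assuming plain string values only.
def render_options(name: str, opts: dict[str, str] | None) -> str:
    if not opts:
        return ""
    pairs = [f"{key} = '{value}'" for key, value in opts.items()]
    return f"{name} ({', '.join(pairs)})"

print(render_options("FORMAT_OPTIONS", {"header": "true", "inferSchema": "true"}))
# FORMAT_OPTIONS (header = 'true', inferSchema = 'true')
```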
    def _create_sql_query(self) -> str:
        """Create the COPY INTO statement from the provided options."""
        escaper = ParamEscaper()
        maybe_with = ""
        if self._encryption is not None or self._credential is not None:
@@ -349,12 +360,166 @@ def _create_sql_query(self) -> str:
        return sql.strip()

    def execute(self, context: Context) -> Any:
        """Execute the COPY INTO command and store the result for lineage reporting."""
        self._sql = self._create_sql_query()
        self.log.info("Executing SQL: %s", self._sql)

        hook = self._get_hook()
        hook.run(self._sql)

    def on_kill(self) -> None:
        # NB: on_kill isn't required for this operator, since query cancellation is
        # handled in the `DatabricksSqlHook.run()` method, which is called in `execute()`.
        ...
    def _parse_input_dataset(self) -> tuple[list[Any], list[Any]]:
        """Parse file_location to build the input dataset."""
        from airflow.providers.common.compat.openlineage.facet import Dataset, Error

        input_datasets: list[Dataset] = []
        extraction_errors: list[Error] = []

        if not self.file_location:
            return input_datasets, extraction_errors

        try:
            from urllib.parse import urlparse

            parsed_uri = urlparse(self.file_location)
            # Only process known schemes
            if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs", "azure", "abfss", "wasbs"):
                raise ValueError(f"Unsupported scheme: {parsed_uri.scheme}")

            scheme = parsed_uri.scheme
            namespace = f"{scheme}://{parsed_uri.netloc}"
            path = parsed_uri.path.lstrip("/") or "/"
            input_datasets.append(Dataset(namespace=namespace, name=path))
        except Exception as e:
            self.log.error("Failed to parse file_location: %s, error: %s", self.file_location, str(e))
            extraction_errors.append(
                Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
            )

        return input_datasets, extraction_errors

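The namespace/name split above follows directly from `urllib.parse.urlparse`; a quick illustration of what the supported schemes resolve to:

```python
from urllib.parse import urlparse

# Mirrors _parse_input_dataset: namespace is scheme://netloc, name is the
# path with the leading slash stripped (falling back to "/" for a bare bucket).
for uri in ("s3://bucket/dir1", "gs://bucket2", "wasbs://container@account.blob.core.windows.net/dir"):
    parsed = urlparse(uri)
    print(f"{parsed.scheme}://{parsed.netloc}", parsed.path.lstrip("/") or "/")
# s3://bucket dir1
# gs://bucket2 /
# wasbs://container@account.blob.core.windows.net dir
```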
    def _create_sql_job_facet(self) -> tuple[dict, list[Any]]:
        """Create SQL job facet from the SQL query."""
        from airflow.providers.common.compat.openlineage.facet import Error, SQLJobFacet
        from airflow.providers.openlineage.sqlparser import SQLParser

        job_facets = {}
        extraction_errors: list[Error] = []

        try:
            import re

            normalized_sql = SQLParser.normalize_sql(self._sql)
            normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ", normalized_sql))
> **Contributor:** I think we usually only use …
>
> **Contributor (Author):** This is done in CopyFromExternalStageToSnowflakeOperator OL implementation here:
> `airflow/providers/src/airflow/providers/snowflake/transfers/copy_into_snowflake.py`, line 287 in c600a95.
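For reference, the whitespace collapsing being discussed behaves like this (`SQLParser.normalize_sql` is applied first in the actual code and is omitted here):

```python
import re

sql = "COPY INTO  schema.table\n\n\nFROM   's3://bucket/dir1'\nFILEFORMAT = CSV"
normalized = re.sub(r"\n+", "\n", re.sub(r" +", " ", sql))
print(normalized)
# COPY INTO schema.table
# FROM 's3://bucket/dir1'
# FILEFORMAT = CSV
```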
            job_facets["sql"] = SQLJobFacet(query=normalized_sql)
        except Exception as e:
            self.log.error("Failed creating SQL job facet: %s", str(e))
> **Contributor:** I think we usually try not to log on error level unless absolutely necessary. Could you review the code and adjust in other places as well? Maybe warning is enough? WDYT?
            extraction_errors.append(
                Error(errorMessage=str(e), stackTrace=None, task="sql_facet_creation", taskNumber=None)
            )

        return job_facets, extraction_errors

    def _build_output_dataset(self) -> tuple[Any, list[Any]]:
        """Build output dataset from table information."""
        from airflow.providers.common.compat.openlineage.facet import Dataset, Error

        output_dataset = None
        extraction_errors: list[Error] = []

        if not self.table_name:
            return output_dataset, extraction_errors

        try:
            table_parts = self.table_name.split(".")
            if len(table_parts) == 3:  # catalog.schema.table
                catalog, schema, table = table_parts
            elif len(table_parts) == 2:  # schema.table
                catalog = None
                schema, table = table_parts
            else:
                catalog = None
                schema = None
                table = self.table_name

            hook = self._get_hook()
            conn = hook.get_connection(hook.databricks_conn_id)
            output_namespace = f"databricks://{conn.host}"

            # Combine schema/table with optional catalog for the final dataset name
            fq_name = table
            if schema:
                fq_name = f"{schema}.{fq_name}"
            if catalog:
                fq_name = f"{catalog}.{fq_name}"

            output_dataset = Dataset(namespace=output_namespace, name=fq_name)
        except Exception as e:
            self.log.error("Failed to construct output dataset: %s", str(e))
            extraction_errors.append(
                Error(
                    errorMessage=str(e),
                    stackTrace=None,
                    task="output_dataset_construction",
                    taskNumber=None,
                )
            )

        return output_dataset, extraction_errors

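The fully qualified name assembled above is just the inverse of the split; a small self-contained sketch of the branching:

```python
def split_table_name(table_name: str) -> tuple[str | None, str | None, str]:
    # Mirrors the catalog.schema.table / schema.table / bare-table branching above.
    parts = table_name.split(".")
    if len(parts) == 3:
        return parts[0], parts[1], parts[2]
    if len(parts) == 2:
        return None, parts[0], parts[1]
    return None, None, table_name

assert split_table_name("main.default.events") == ("main", "default", "events")
assert split_table_name("default.events") == (None, "default", "events")
assert split_table_name("events") == (None, None, "events")
```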
    def get_openlineage_facets_on_complete(self, task_instance):
        """
        Compute OpenLineage facets for the COPY INTO command.

        Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and build an
        input dataset list and an output dataset (the Delta table).
        """
        from airflow.providers.common.compat.openlineage.facet import ExtractionErrorRunFacet
        from airflow.providers.openlineage.extractors import OperatorLineage

        if not self._sql:
            self.log.warning("No SQL query found, returning empty OperatorLineage.")
            return OperatorLineage()

        # Get input datasets and any parsing errors
        input_datasets, extraction_errors = self._parse_input_dataset()

        # Create SQL job facet
        job_facets, sql_errors = self._create_sql_job_facet()
        extraction_errors.extend(sql_errors)

        run_facets = {}
        if extraction_errors:
            run_facets["extractionError"] = ExtractionErrorRunFacet(
                totalTasks=1,
                failedTasks=len(extraction_errors),
                errors=extraction_errors,
            )
            # Return only error facets for invalid URIs
            return OperatorLineage(
                inputs=[],
                outputs=[],
                job_facets=job_facets,
                run_facets=run_facets,
            )
> **Contributor** (on lines +496 to +508): Shouldn't we try to return the output dataset even if the inputs are incorrect?
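One way the tail of `get_openlineage_facets_on_complete` could be reshaped along these lines (a hypothetical method-body sketch, not the merged code): skip the early return, attach the error facet once, and still emit whatever output dataset was resolved:

```python
# Hypothetical method tail following the reviewer's suggestion.
output_dataset, output_errors = self._build_output_dataset()
extraction_errors.extend(output_errors)

run_facets = {}
if extraction_errors:
    run_facets["extractionError"] = ExtractionErrorRunFacet(
        totalTasks=1,
        failedTasks=len(extraction_errors),
        errors=extraction_errors,
    )

return OperatorLineage(
    inputs=input_datasets,
    outputs=[output_dataset] if output_dataset else [],
    job_facets=job_facets,
    run_facets=run_facets,
)
```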
        # Build output dataset
        output_dataset, output_errors = self._build_output_dataset()
        if output_errors:
            extraction_errors.extend(output_errors)
            run_facets["extractionError"] = ExtractionErrorRunFacet(
                totalTasks=1,
                failedTasks=len(extraction_errors),
                errors=extraction_errors,
            )

        return OperatorLineage(
            inputs=input_datasets,
            outputs=[output_dataset] if output_dataset else [],
            job_facets=job_facets,
            run_facets=run_facets,
        )
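As a usage sketch of the new lineage path (the connection id and bucket names below are made up for illustration):

```python
# Hypothetical task definition; the OpenLineage listener calls
# get_openlineage_facets_on_complete() after execute() finishes.
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

copy_task = DatabricksCopyIntoOperator(
    task_id="copy_events",
    databricks_conn_id="databricks_default",  # assumed connection id
    table_name="main.default.events",
    file_location="s3://my-bucket/events",
    file_format="CSV",
)
# Expected lineage after execution:
#   input:  Dataset(namespace="s3://my-bucket", name="events")
#   output: Dataset(namespace="databricks://<workspace host>", name="main.default.events")
#   job facet: the normalized COPY INTO statement as a SQLJobFacet
```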
---
@@ -17,10 +17,17 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    SQLJobFacet,
)
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
from airflow.providers.openlineage.extractors import OperatorLineage

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"

@@ -140,10 +147,8 @@ def test_copy_with_encryption_and_credential():
    assert (
        op._create_sql_query()
        == f"""COPY INTO test
FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc'))
FILEFORMAT = CSV""".strip()
    )

@@ -253,3 +258,150 @@ def test_templating(create_task_instance_of_operator, session):
    assert task.files == "files"
    assert task.table_name == "table-name"
    assert task.databricks_conn_id == "databricks-conn-id"


@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook")
def test_get_openlineage_facets_on_complete_s3(mock_hook):
    """Test OpenLineage facets generation for S3 source."""
    mock_hook().run.return_value = [
        {"file": "s3://bucket/dir1/file1.csv"},
        {"file": "s3://bucket/dir1/file2.csv"},
    ]
    mock_hook().get_connection().host = "databricks.com"

    op = DatabricksCopyIntoOperator(
        task_id="test",
        table_name="schema.table",
        file_location="s3://bucket/dir1",
        file_format="CSV",
    )
    op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'"
    op._result = mock_hook().run.return_value
> **Contributor:** I think this is already gone from the operator, so the tests should be adjusted?
|
|
||
| lineage = op.get_openlineage_facets_on_complete(None) | ||
|
|
||
| assert lineage == OperatorLineage( | ||
| inputs=[Dataset(namespace="s3://bucket", name="dir1")], | ||
| outputs=[Dataset(namespace="databricks://databricks.com", name="schema.table")], | ||
| job_facets={"sql": SQLJobFacet(query="COPY INTO schema.table FROM 's3://bucket/dir1'")}, | ||
| run_facets={}, | ||
| ) | ||
|
|
||
|
|
||
| @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") | ||
| def test_get_openlineage_facets_on_complete_with_errors(mock_hook): | ||
| """Test OpenLineage facets generation with extraction errors.""" | ||
| mock_hook().run.return_value = [ | ||
| {"file": "s3://bucket/dir1/file1.csv"}, | ||
| {"file": "invalid://location/file.csv"}, # Invalid URI | ||
| {"file": "azure://account.invalid.windows.net/container/file.csv"}, # Invalid Azure URI | ||
| ] | ||
| mock_hook().get_connection().host = "databricks.com" | ||
|
> **Contributor** (on lines +292 to +299): So we are passing invalid URIs and then checking that there are no extraction errors? Is this test valid?
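A note on why no errors surface here: lineage parsing only looks at `file_location` (a valid `s3://` URI for the operator constructed just below); the per-file URIs returned by the mocked `run()` never reach `_parse_input_dataset`. A sketch of what that call would yield:

```python
# Sketch: with file_location="s3://bucket/dir1", parsing succeeds regardless
# of what the COPY INTO result rows contain.
inputs, errors = op._parse_input_dataset()
assert [(d.namespace, d.name) for d in inputs] == [("s3://bucket", "dir1")]
assert errors == []
```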
|
|
||
| op = DatabricksCopyIntoOperator( | ||
| task_id="test", | ||
| table_name="schema.table", | ||
| file_location="s3://bucket/dir1", | ||
| file_format="CSV", | ||
| ) | ||
| op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'" | ||
| op._result = mock_hook().run.return_value | ||
|
|
||
| lineage = op.get_openlineage_facets_on_complete(None) | ||
|
|
||
| # Check inputs and outputs | ||
| assert len(lineage.inputs) == 1 | ||
| assert lineage.inputs[0].namespace == "s3://bucket" | ||
| assert lineage.inputs[0].name == "dir1" | ||
|
|
||
| assert len(lineage.outputs) == 1 | ||
| assert lineage.outputs[0].namespace == "databricks://databricks.com" | ||
| assert lineage.outputs[0].name == "schema.table" | ||
|
|
||
| # Check facets exist and have correct structure | ||
| assert "sql" in lineage.job_facets | ||
| assert lineage.job_facets["sql"].query == "COPY INTO schema.table FROM 's3://bucket/dir1'" | ||
|
|
||
| assert "extractionError" not in lineage.run_facets | ||
|
|
||
|
|
||
| @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") | ||
| def test_get_openlineage_facets_on_complete_no_sql(mock_hook): | ||
|
> **Contributor:** Maybe we should run the execute and then explicitly overwrite the self._sql? Or at least manually make sure the _sql is None. This test assumes the self._sql is initialized as None, but we don't check it.
| """Test OpenLineage facets generation when no SQL is available.""" | ||
| op = DatabricksCopyIntoOperator( | ||
| task_id="test", | ||
| table_name="schema.table", | ||
| file_location="s3://bucket/dir1", | ||
| file_format="CSV", | ||
| ) | ||
|
|
||
| lineage = op.get_openlineage_facets_on_complete(None) | ||
| assert lineage == OperatorLineage() | ||
|
|
||
|
|
||
| @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") | ||
| def test_get_openlineage_facets_on_complete_gcs(mock_hook): | ||
| """Test OpenLineage facets generation specifically for GCS paths.""" | ||
| mock_hook().run.return_value = [ | ||
| {"file": "gs://bucket1/dir1/file1.csv"}, | ||
| {"file": "gs://bucket1/dir2/nested/file2.csv"}, | ||
| {"file": "gs://bucket2/file3.csv"}, | ||
| {"file": "gs://bucket2"}, # Edge case: root path | ||
| {"file": "gs://invalid-bucket/@#$%"}, # Invalid path | ||
| ] | ||
| mock_hook().get_connection.return_value.host = "databricks.com" | ||
| mock_hook().query_ids = ["query_123"] | ||
|
|
||
| op = DatabricksCopyIntoOperator( | ||
| task_id="test", | ||
| table_name="catalog.schema.table", | ||
| file_location="gs://location", | ||
| file_format="CSV", | ||
| ) | ||
| op.execute(None) | ||
| result = op.get_openlineage_facets_on_complete(None) | ||
|
|
||
| # Check inputs - only one input from file_location | ||
| assert len(result.inputs) == 1 | ||
| assert result.inputs[0].namespace == "gs://location" | ||
| assert result.inputs[0].name == "/" | ||
|
|
||
| # Check outputs | ||
| assert len(result.outputs) == 1 | ||
| assert result.outputs[0].namespace == "databricks://databricks.com" | ||
| assert result.outputs[0].name == "catalog.schema.table" | ||
|
|
||
| # Check SQL job facet | ||
| assert "sql" in result.job_facets | ||
| assert "COPY INTO catalog.schema.table" in result.job_facets["sql"].query | ||
| assert "FILEFORMAT = CSV" in result.job_facets["sql"].query | ||
|
> **Contributor** (on lines +376 to +377): Maybe we should check the whole query or at least also check if the GCS path is there? WDYT?
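A stricter assertion along the lines of this comment might pin the whole normalized query (hypothetical; the exact text depends on `_create_sql_query` and the normalization shown earlier):

```python
# Hypothetical full-query check that also covers the GCS path.
assert result.job_facets["sql"].query == (
    "COPY INTO catalog.schema.table\nFROM 'gs://location'\nFILEFORMAT = CSV"
)
```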
|
|
||
|
|
||
| @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") | ||
| def test_get_openlineage_facets_on_complete_invalid_location(mock_hook): | ||
| """Test OpenLineage facets generation with invalid file_location.""" | ||
| mock_hook().get_connection().host = "databricks.com" | ||
|
|
||
| op = DatabricksCopyIntoOperator( | ||
| task_id="test", | ||
| table_name="schema.table", | ||
| file_location="invalid://location", # Invalid location | ||
| file_format="CSV", | ||
| ) | ||
| op._sql = "COPY INTO schema.table FROM 'invalid://location'" | ||
| op._result = [{"file": "s3://bucket/file.csv"}] | ||
|
> **Contributor** (on lines +391 to +392): Why not actually execute the operator instead?
|
|
||
| lineage = op.get_openlineage_facets_on_complete(None) | ||
|
|
||
| # Should have no inputs due to invalid location | ||
| assert len(lineage.inputs) == 0 | ||
|
|
||
| # Should not have output and SQL facets | ||
| assert len(lineage.outputs) == 0 | ||
| assert "sql" in lineage.job_facets | ||
|
|
||
| # Should have extraction error facet | ||
| assert "extractionError" in lineage.run_facets | ||
| assert lineage.run_facets["extractionError"].totalTasks == 1 | ||
| assert lineage.run_facets["extractionError"].failedTasks == 1 | ||
| assert len(lineage.run_facets["extractionError"].errors) == 1 | ||
> I think this is no longer needed?