8 changes: 4 additions & 4 deletions dev/breeze/tests/test_selective_checks.py
@@ -1948,7 +1948,7 @@ def test_expected_output_push(
),
{
"selected-providers-list-as-string": "amazon apache.beam apache.cassandra apache.kafka "
"cncf.kubernetes common.compat common.sql "
"cncf.kubernetes common.compat common.sql databricks "
"facebook google hashicorp http microsoft.azure microsoft.mssql mysql "
"openlineage oracle postgres presto salesforce samba sftp ssh standard trino",
"all-python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -1960,7 +1960,7 @@ def test_expected_output_push(
"skip-providers-tests": "false",
"docs-build": "true",
"docs-list-as-string": "apache-airflow helm-chart amazon apache.beam apache.cassandra "
"apache.kafka cncf.kubernetes common.compat common.sql facebook google hashicorp http microsoft.azure "
"apache.kafka cncf.kubernetes common.compat common.sql databricks facebook google hashicorp http microsoft.azure "
"microsoft.mssql mysql openlineage oracle postgres "
"presto salesforce samba sftp ssh standard trino",
"skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
@@ -1974,7 +1974,7 @@ def test_expected_output_push(
{
"description": "amazon...standard",
"test_types": "Providers[amazon] Providers[apache.beam,apache.cassandra,"
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
"apache.kafka,cncf.kubernetes,common.compat,common.sql,databricks,facebook,"
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
"Providers[google] "
@@ -2245,7 +2245,7 @@ def test_upgrade_to_newer_dependencies(
("providers/google/docs/some_file.rst",),
{
"docs-list-as-string": "amazon apache.beam apache.cassandra apache.kafka "
"cncf.kubernetes common.compat common.sql facebook google hashicorp http "
"cncf.kubernetes common.compat common.sql databricks facebook google hashicorp http "
"microsoft.azure microsoft.mssql mysql openlineage oracle "
"postgres presto salesforce samba sftp ssh standard trino",
},
1 change: 1 addition & 0 deletions providers/databricks/docs/index.rst
@@ -132,6 +132,7 @@ Dependent package
================================================================================================================== =================
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_ ``google``
`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
================================================================================================================== =================

7 changes: 7 additions & 0 deletions providers/databricks/pyproject.toml
@@ -93,6 +93,12 @@ dependencies = [
"sqlalchemy" = [
"databricks-sqlalchemy>=1.0.2",
]
"google" = [
"apache-airflow-providers-google>=10.24.0"
]
"avro" = [
"fastavro>=1.9.0"
]

[dependency-groups]
dev = [
@@ -101,6 +107,7 @@ dev = [
"apache-airflow-devel-common",
"apache-airflow-providers-common-compat",
"apache-airflow-providers-common-sql",
"apache-airflow-providers-google",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
# Need to exclude 1.3.0 due to missing aarch64 binaries, fixed with 1.3.1++
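The two new optional extras declared above ("google" for GCS output, "avro" for fastavro-based Avro output) sit alongside the provider's existing "sqlalchemy" extra. Assuming they are resolved like any other package extra, an installation pulling in both optional dependency sets would look roughly like:

    pip install "apache-airflow-providers-databricks[google,avro]"

which brings in apache-airflow-providers-google and fastavro in addition to the Databricks provider itself.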
@@ -21,13 +21,20 @@

import csv
import json
import os
from collections.abc import Sequence
from functools import cached_property
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, ClassVar
from urllib.parse import urlparse

from databricks.sql.utils import ParamEscaper

from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
from airflow.providers.common.compat.sdk import (
AirflowException,
AirflowOptionalProviderFeatureException,
BaseOperator,
)
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

@@ -62,13 +69,27 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+ (templated)
:param schema: An optional initial schema to use. Requires DBR version 9.0+ (templated)
:param output_path: optional string specifying the file to which write selected data. (templated)
:param output_format: format of output data if ``output_path` is specified.
Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
Supports local file paths and GCS URIs (e.g., ``gs://bucket/path/file.parquet``).
When using GCS URIs, requires the ``apache-airflow-providers-google`` package.
:param output_format: format of output data if ``output_path`` is specified.
Possible values are ``csv``, ``json``, ``jsonl``, ``parquet``, ``avro``. Default is ``csv``.
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
:param gcp_conn_id: The connection ID to use for connecting to Google Cloud when using GCS output path.
Default is ``google_cloud_default``.
:param gcs_impersonation_chain: Optional service account to impersonate using short-term
credentials for GCS upload, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request. (templated)
"""

template_fields: Sequence[str] = tuple(
{"_output_path", "schema", "catalog", "http_headers", "databricks_conn_id"}
{
"_output_path",
"schema",
"catalog",
"http_headers",
"databricks_conn_id",
"_gcs_impersonation_chain",
}
| set(SQLExecuteQueryOperator.template_fields)
)
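As a usage sketch of the parameters documented above, a task that writes query results as Parquet to GCS might look like the following; the connection IDs, SQL endpoint name, catalog/table names, and bucket are hypothetical placeholders, not values taken from this change:

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    export_to_gcs = DatabricksSqlOperator(
        task_id="export_query_results",
        databricks_conn_id="databricks_default",        # placeholder Databricks connection
        sql_endpoint_name="my-sql-endpoint",            # placeholder SQL warehouse/endpoint
        sql="SELECT * FROM my_catalog.my_schema.my_table",
        output_path="gs://example-bucket/exports/results.parquet",  # gs:// URI triggers the GCS upload path
        output_format="parquet",
        gcp_conn_id="google_cloud_default",
        gcs_impersonation_chain=None,                   # or a service account (chain) as described above
    )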

@@ -90,6 +111,8 @@ def __init__(
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
client_parameters: dict[str, Any] | None = None,
gcp_conn_id: str = "google_cloud_default",
gcs_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=databricks_conn_id, **kwargs)
@@ -105,6 +128,8 @@
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self._gcp_conn_id = gcp_conn_id
self._gcs_impersonation_chain = gcs_impersonation_chain

@cached_property
def _hook(self) -> DatabricksSqlHook:
@@ -127,41 +152,151 @@ def get_db_hook(self) -> DatabricksSqlHook:
def _should_run_output_processing(self) -> bool:
return self.do_xcom_push or bool(self._output_path)

@property
def _is_gcs_output(self) -> bool:
"""Check if the output path is a GCS URI."""
return self._output_path.startswith("gs://") if self._output_path else False

def _parse_gcs_path(self, path: str) -> tuple[str, str]:
"""Parse a GCS URI into bucket and object name."""
parsed = urlparse(path)
bucket = parsed.netloc
object_name = parsed.path.lstrip("/")
return bucket, object_name

def _upload_to_gcs(self, local_path: str, gcs_path: str) -> None:
"""Upload a local file to GCS."""
try:
from airflow.providers.google.cloud.hooks.gcs import GCSHook
except ImportError:
raise AirflowOptionalProviderFeatureException(
"The 'apache-airflow-providers-google' package is required for GCS output. "
"Install it with: pip install apache-airflow-providers-google"
)

bucket, object_name = self._parse_gcs_path(gcs_path)
hook = GCSHook(
gcp_conn_id=self._gcp_conn_id,
impersonation_chain=self._gcs_impersonation_chain,
)
hook.upload(
bucket_name=bucket,
object_name=object_name,
filename=local_path,
)
self.log.info("Uploaded output to %s", gcs_path)

def _write_parquet(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
"""Write data to a Parquet file."""
import pyarrow as pa
import pyarrow.parquet as pq

data: dict[str, list] = {name: [] for name in field_names}
for row in rows:
row_dict = row._asdict()
for name in field_names:
data[name].append(row_dict[name])

table = pa.Table.from_pydict(data)
pq.write_table(table, file_path)

def _write_avro(self, file_path: str, field_names: list[str], rows: list[Any]) -> None:
"""Write data to an Avro file using fastavro."""
try:
from fastavro import writer
except ImportError:
raise AirflowOptionalProviderFeatureException(
"The 'fastavro' package is required for Avro output. Install it with: pip install fastavro"
)

data: dict[str, list] = {name: [] for name in field_names}
for row in rows:
row_dict = row._asdict()
for name in field_names:
data[name].append(row_dict[name])

schema_fields = []
for name in field_names:
sample_val = next(
(data[name][i] for i in range(len(data[name])) if data[name][i] is not None), None
)
if sample_val is None:
avro_type = ["null", "string"]
elif isinstance(sample_val, bool):
avro_type = ["null", "boolean"]
elif isinstance(sample_val, int):
avro_type = ["null", "long"]
elif isinstance(sample_val, float):
avro_type = ["null", "double"]
else:
avro_type = ["null", "string"]
schema_fields.append({"name": name, "type": avro_type})

avro_schema = {
"type": "record",
"name": "QueryResult",
"fields": schema_fields,
}

records = [row._asdict() for row in rows]
with open(file_path, "wb") as f:
writer(f, avro_schema, records)

def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
if not self._output_path:
return list(zip(descriptions, results))
if not self._output_format:
raise AirflowException("Output format should be specified!")
# Output to a file only the result of last query

last_description = descriptions[-1]
last_results = results[-1]
if last_description is None:
raise AirflowException("There is missing description present for the output file. .")
raise AirflowException("There is missing description present for the output file.")
field_names = [field[0] for field in last_description]
if self._output_format.lower() == "csv":
with open(self._output_path, "w", newline="") as file:
if self._csv_params:
csv_params = self._csv_params
else:
csv_params = {}
write_header = csv_params.get("header", True)
if "header" in csv_params:
del csv_params["header"]
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
if write_header:
writer.writeheader()
for row in last_results:
writer.writerow(row._asdict())
elif self._output_format.lower() == "json":
with open(self._output_path, "w") as file:
file.write(json.dumps([row._asdict() for row in last_results]))
elif self._output_format.lower() == "jsonl":
with open(self._output_path, "w") as file:
for row in last_results:
file.write(json.dumps(row._asdict()))
file.write("\n")

if self._is_gcs_output:
suffix = f".{self._output_format.lower()}"
tmp_file = NamedTemporaryFile(mode="w", suffix=suffix, delete=False, newline="")
local_path = tmp_file.name
tmp_file.close()
else:
raise AirflowException(f"Unsupported output format: '{self._output_format}'")
local_path = self._output_path

try:
output_format = self._output_format.lower()
if output_format == "csv":
with open(local_path, "w", newline="") as file:
if self._csv_params:
csv_params = self._csv_params.copy()
else:
csv_params = {}
write_header = csv_params.pop("header", True)
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
if write_header:
writer.writeheader()
for row in last_results:
writer.writerow(row._asdict())
elif output_format == "json":
with open(local_path, "w") as file:
file.write(json.dumps([row._asdict() for row in last_results]))
elif output_format == "jsonl":
with open(local_path, "w") as file:
for row in last_results:
file.write(json.dumps(row._asdict()))
file.write("\n")
elif output_format == "parquet":
self._write_parquet(local_path, field_names, last_results)
elif output_format == "avro":
self._write_avro(local_path, field_names, last_results)
else:
raise ValueError(f"Unsupported output format: '{self._output_format}'")

if self._is_gcs_output:
self._upload_to_gcs(local_path, self._output_path)
finally:
if self._is_gcs_output and os.path.exists(local_path):
os.unlink(local_path)

return list(zip(descriptions, results))


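To illustrate the type sampling in _write_avro above: the first non-null value of each column determines the Avro union type, with bool checked before int (a bool is also an int in Python), and all-null columns falling back to string. For a hypothetical result row such as Row(id=1, price=9.99, active=True, note=None), the inferred schema would be equivalent to:

    from fastavro import parse_schema

    inferred = {
        "type": "record",
        "name": "QueryResult",
        "fields": [
            {"name": "id", "type": ["null", "long"]},        # int -> long
            {"name": "price", "type": ["null", "double"]},   # float -> double
            {"name": "active", "type": ["null", "boolean"]}, # bool -> boolean
            {"name": "note", "type": ["null", "string"]},    # all-null column falls back to string
        ],
    }
    parse_schema(inferred)  # fastavro accepts this as a well-formed record schema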