Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/databricks/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Requirements
PIP package Version required
======================================= ==================
``apache-airflow`` ``>=2.10.0``
``apache-airflow-providers-common-sql`` ``>=1.21.0``
``apache-airflow-providers-common-sql`` ``>=1.27.0``
``requests`` ``>=2.31.0,<3``
``databricks-sql-connector`` ``>=3.0.0``
``aiohttp`` ``>=3.9.2,<4``
Expand Down
3 changes: 2 additions & 1 deletion providers/databricks/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ requires-python = "~=3.9"
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-compat>=1.6.0",
"apache-airflow-providers-common-sql>=1.21.0",
"apache-airflow-providers-common-sql>=1.27.0",
"requests>=2.31.0,<3",
"databricks-sql-connector>=3.0.0",
"databricks-sqlalchemy>=1.0.2",
Expand Down Expand Up @@ -101,6 +101,7 @@ dev = [
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"deltalake>=0.12.0",
"apache-airflow-providers-microsoft-azure",
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from unittest import mock
from unittest.mock import PropertyMock, patch

import pandas as pd
import polars as pl
import pytest
from databricks.sql.types import Row

Expand Down Expand Up @@ -504,3 +506,57 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
)
with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
hook.get_openlineage_database_specific_lineage(mock.MagicMock())


@pytest.mark.parametrize(
    "df_type, df_class, description",
    [
        pytest.param("pandas", pd.DataFrame, [("col",)], id="pandas-dataframe"),
        pytest.param(
            "polars",
            pl.DataFrame,
            [("col", None, None, None, None, None, None)],
            id="polars-dataframe",
        ),
    ],
)
def test_get_df(df_type, df_class, description):
    """get_df builds the requested DataFrame flavor from the cursor's results.

    The pandas path reads ``description``/``fetchall`` straight off the cursor,
    while the polars path consumes the object returned by ``cursor.execute()``;
    both must yield a frame of the expected class with the mocked column/rows.
    """
    hook = DatabricksSqlHook()
    sql = "SQL"
    expected_column = "col"
    rows = [("row1",), ("row2",)]

    with mock.patch(
        "airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook.get_conn"
    ) as mock_get_conn:
        cursor_mock = mock.MagicMock()
        if df_type == "pandas":
            # pandas: hook pulls description and rows directly from the cursor.
            cursor_mock.description = description
            cursor_mock.fetchall.return_value = rows
        else:
            # polars: hook pulls description and rows from execute()'s return value.
            execute_result = mock.MagicMock()
            execute_result.description = description
            execute_result.fetchall.return_value = rows
            cursor_mock.execute.return_value = execute_result
        mock_get_conn.return_value.cursor.return_value = cursor_mock

        df = hook.get_df(sql, df_type=df_type)
        cursor_mock.execute.assert_called_once_with(sql)

        if df_type == "pandas":
            cursor_mock.fetchall.assert_called_once_with()
            assert df.columns[0] == expected_column
            assert df.iloc[0][0] == "row1"
            assert df.iloc[1][0] == "row2"
        else:
            execute_result.fetchall.assert_called_once_with()
            assert df.columns[0] == expected_column
            assert df.row(0)[0] == rows[0][0]
            assert df.row(1)[0] == rows[1][0]

        assert isinstance(df, df_class)
Loading