Merged

22 commits
1dc6553
feat: Databricks sql hook returns json-serializable namedtuple
joffreybienvenu-infrabel Dec 13, 2023
6a35e30
fix: Ignore mypy
joffreybienvenu-infrabel Dec 13, 2023
e09a34b
fix: correctly describe serialization
joffreybienvenu-infrabel Dec 13, 2023
b11f883
feat: disable the return of serializable namedtuple by default
joffreybienvenu-infrabel Dec 13, 2023
837d460
fix: make DatabricksSqlOperator return serializable object by default
joffreybienvenu-infrabel Dec 13, 2023
c653798
fix: add `return_serializable` to databricks operator tests
joffreybienvenu-infrabel Dec 13, 2023
7c00732
fix: docstring spellchecks
joffreybienvenu-infrabel Dec 13, 2023
eceec0b
fix: rewrite docstring and warning message
joffreybienvenu-infrabel Dec 14, 2023
67d1b90
fix: docstring spellchecks
joffreybienvenu-infrabel Dec 14, 2023
1e84a75
feat: rename `make_serializable` into `make_common_data_structure` fo…
joffreybienvenu-infrabel Dec 16, 2023
e76d5b8
feat: Add typing for `_make_common_data_structure` in databricks
joffreybienvenu-infrabel Dec 16, 2023
a82553e
feat: Apply typing to ODBC hook
joffreybienvenu-infrabel Dec 16, 2023
d3510ef
fix: patch pyodbc.Row to make isinstance() checks pass with row_mock …
joffreybienvenu-infrabel Dec 17, 2023
0f20dfb
fix: Use List in type casting for python38 compatibility
joffreybienvenu-infrabel Dec 17, 2023
ab87191
feat: Remove all mentions of serialization in databricks hook
joffreybienvenu-infrabel Dec 17, 2023
10e22c7
feat: Rename `_make_serializable` into `_make_common_data_structure` …
joffreybienvenu-infrabel Dec 18, 2023
4796882
fix: use `return_tuple` in databricks tests
joffreybienvenu-infrabel Dec 18, 2023
61255ed
fix: add deprecated _make_serializable method in odbc.py
joffreybienvenu-infrabel Dec 18, 2023
e54ca54
bump: min sql.common to 1.9.1 for odbc and databricks
joffreybienvenu-infrabel Dec 18, 2023
e0c7ad4
fix: add version 1.9.1 to provider.yaml of common.sql
joffreybienvenu-infrabel Dec 18, 2023
74a763d
fix: move back-compat in common.sql
joffreybienvenu-infrabel Dec 18, 2023
affd2f4
feat: Implement strict `tuple`/`list[tuple]` return in ODBC and Datab…
joffreybienvenu-infrabel Dec 21, 2023
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -385,8 +385,8 @@ repos:
files: ^dev/breeze/src/airflow_breeze/utils/docker_command_utils\.py$|^scripts/ci/docker_compose/local\.yml$
pass_filenames: false
additional_dependencies: ['rich>=12.4.4']
- id: check-common-sql-dependency-make-serializable
name: Check dependency of SQL Providers with '_make_serializable'
- id: check-sql-dependency-common-data-structure
name: Check dependency of SQL Providers with common data structure
entry: ./scripts/ci/pre_commit/pre_commit_check_common_sql_dependency.py
language: python
files: ^airflow/providers/.*/hooks/.*\.py$
4 changes: 2 additions & 2 deletions STATIC_CODE_CHECKS.rst
@@ -170,8 +170,6 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-cncf-k8s-only-for-executors | Check cncf.kubernetes imports used for executors only | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-common-sql-dependency-make-serializable | Check dependency of SQL Providers with '_make_serializable' | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-core-deprecation-classes | Verify usage of Airflow deprecation classes in core | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
@@ -240,6 +238,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-setup-order | Check order of dependencies in setup.cfg and setup.py | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-sql-dependency-common-data-structure | Check dependency of SQL Providers with common data structure | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-start-date-not-used-in-defaults | start_date not to be defined in default_args in example_dags | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-system-tests-present | Check if system tests have required segments of code | |
49 changes: 32 additions & 17 deletions airflow/providers/common/sql/hooks/sql.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import contextlib
import warnings
from contextlib import closing
from datetime import datetime
from typing import (
@@ -24,6 +26,7 @@
Callable,
Generator,
Iterable,
List,
Mapping,
Protocol,
Sequence,
@@ -36,7 +39,7 @@
import sqlparse
from sqlalchemy import create_engine

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
@@ -122,10 +125,10 @@ class DbApiHook(BaseHook):
"""
Abstract base class for sql hooks.

When subclassing, maintainers can override the `_make_serializable` method:
When subclassing, maintainers can override the `_make_common_data_structure` method:
This method transforms the result of the handler method (typically `cursor.fetchall()`) into
JSON-serializable objects. Most of the time, the underlying SQL library already returns tuples from
its cursor, and the `_make_serializable` method can be ignored.
objects common across all Hooks derived from this class (tuples). Most of the time, the underlying SQL
library already returns tuples from its cursor, and the `_make_common_data_structure` method can be ignored.

:param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
if you change the schema parameter value in the constructor of the derived Hook, such change
@@ -305,7 +308,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
...

def run(
@@ -316,7 +319,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""Run a command or a list of commands.

Pass a list of SQL statements to the sql parameter to get them to
@@ -392,7 +395,7 @@ def run(
self._run_command(cur, sql_statement, parameters)

if handler is not None:
result = self._make_serializable(handler(cur))
result = self._make_common_data_structure(handler(cur))
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
@@ -412,19 +415,31 @@ def run(
else:
return results

@staticmethod
def _make_serializable(result: Any) -> Any:
"""Ensure the data returned from an SQL command is JSON-serializable.
def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple | list[tuple]:
"""Ensure the data returned from an SQL command is a standard tuple or list[tuple].

This method is intended to be overridden by subclasses of the `DbApiHook`. Its purpose is to
transform the result of an SQL command (typically returned by cursor methods) into a
JSON-serializable format.
transform the result of an SQL command (typically returned by cursor methods) into a common
data structure (a tuple or list[tuple]) across all DBApiHook derived Hooks, as defined in the
ADR-0002 of the sql provider.

If this method is not overridden, the result data is returned as-is. If the output of the cursor
is already a common data structure, this method should be ignored.
"""
# Back-compatibility call for providers implementing old `_make_serializable` method.
with contextlib.suppress(AttributeError):
result = self._make_serializable(result=result) # type: ignore[attr-defined]
warnings.warn(
"The `_make_serializable` method is deprecated and support will be removed in a future "
f"version of the common.sql provider. Please update the {self.__class__.__name__}'s provider "
"to a version based on common.sql >= 1.9.1.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

If this method is not overridden, the result data is returned as-is.
If the output of the cursor is already JSON-serializable, this method
should be ignored.
"""
return result
if isinstance(result, Sequence):
return cast(List[tuple], result)
return cast(tuple, result)

def _run_command(self, cur, sql_statement, parameters):
"""Run a statement using an already open cursor."""
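
To make the new contract concrete, below is a minimal sketch of how a subclass could override the hook point. The driver row type and hook name are hypothetical; only `DbApiHook` and `_make_common_data_structure` come from this PR:

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from airflow.providers.common.sql.hooks.sql import DbApiHook


class MyDriverHook(DbApiHook):
    """Hypothetical hook whose DB driver returns vendor-specific row objects."""

    def _make_common_data_structure(self, result: Any) -> tuple | list[tuple]:
        # Normalize the driver's rows into plain tuples so every
        # DbApiHook-derived hook hands run() the same data shape.
        if isinstance(result, Sequence) and not isinstance(result, tuple):
            return [tuple(row) for row in result]
        return tuple(result)
```

Hooks still defining the old `_make_serializable` keep working through the back-compat call above, but now emit an `AirflowProviderDeprecationWarning`.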
1 change: 1 addition & 0 deletions airflow/providers/common/sql/provider.yaml
@@ -24,6 +24,7 @@ description: |
suspended: false
source-date-epoch: 1701983370
versions:
- 1.9.1
- 1.9.0
- 1.8.1
- 1.8.0
65 changes: 52 additions & 13 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,19 +16,32 @@
# under the License.
from __future__ import annotations

import warnings
from collections import namedtuple
from contextlib import closing
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Mapping,
Sequence,
TypeVar,
cast,
overload,
)

from databricks import sql # type: ignore[attr-defined]
from databricks.sql.types import Row

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.types import Row

LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

@@ -52,6 +65,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
on every request
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param return_tuple: Return a ``namedtuple`` object instead of a ``databricks.sql.Row`` object. Defaults
to False. In a future release of the provider, this will become True by default. This parameter
ensures backward-compatibility during the transition phase to common tuple objects for all hooks based
on DbApiHook. This flag will also be removed in a future release.
:param kwargs: Additional parameters internal to Databricks SQL Connector parameters
"""

@@ -68,6 +85,7 @@ def __init__(
catalog: str | None = None,
schema: str | None = None,
caller: str = "DatabricksSqlHook",
return_tuple: bool = False,
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
@@ -80,8 +98,18 @@ def __init__(
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self.return_tuple = return_tuple
self.additional_params = kwargs

if not self.return_tuple:
warnings.warn(
"""Returning a raw `databricks.sql.Row` object is deprecated. A namedtuple will be
returned instead in a future release of the databricks provider. Set `return_tuple=True` to
enable this behavior.""",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

def _get_extra_config(self) -> dict[str, Any | None]:
extra_params = copy(self.databricks_conn.extra_dejson)
for arg in ["http_path", "session_configuration", *self.extra_parameters]:
@@ -167,7 +195,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
Contributor Author
@Joffreybvn Joffreybvn Dec 21, 2023
Mypy was not happy with `-> tuple | list[tuple] | list[list[tuple] | tuple]` as output - which is correct. I got: error: "Overloaded function signatures 1 and 2 overlap with incompatible return types".

Following the mypy docs, I added a None to stop it complaining. But that's not correct: this method cannot return None when a handler is provided...

Implementing definitive typing is out of scope for this PR, and I admit mypy beat me on that case. Thus I leave that for #36224 - but I'm willing to help on that issue too, of course.
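
For readers unfamiliar with that mypy error, a generic reproduction (not this PR's actual signatures) looks like this:

```python
from typing import overload


@overload
def f(x: int) -> int: ...
@overload
def f(x: object) -> str: ...  # error: Overloaded function signatures 1 and 2
                              # overlap with incompatible return types


def f(x: object) -> int | str:
    return x if isinstance(x, int) else str(x)
```

An `int` argument matches both overloads while the declared return types disagree, which is the overlap mypy rejects; widening the second overload's return (as done in this PR with `| None`) makes the returns compatible and silences the error, at the cost of precision.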

...

def run(
@@ -178,7 +206,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""
Run a command or a list of commands.

@@ -223,7 +251,12 @@ def run(
with closing(conn.cursor()) as cur:
self._run_command(cur, sql_statement, parameters)
if handler is not None:
result = self._make_serializable(handler(cur))
raw_result = handler(cur)
if self.return_tuple:
result = self._make_common_data_structure(raw_result)
else:
# Returning the raw result is deprecated, and does not comply with the current common.sql interface
result = raw_result # type: ignore[assignment]
if return_single_query_results(sql, return_last, split_statements):
results = [result]
self.descriptions = [cur.description]
@@ -241,14 +274,20 @@ def run(
else:
return results

@staticmethod
def _make_serializable(result):
"""Transform the databricks Row objects into JSON-serializable lists."""
def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
"""Transform the databricks Row objects into namedtuple."""
# Below ignored lines respect namedtuple docstring, but mypy does not support dynamically
# instantiated namedtuples, and never will: https://github.com/python/mypy/issues/848
if isinstance(result, list):
return [list(row) for row in result]
elif isinstance(result, Row):
return list(result)
return result
rows: list[Row] = result
rows_fields = rows[0].__fields__
rows_object = namedtuple("Row", rows_fields) # type: ignore[misc]
return cast(List[tuple], [rows_object(*row) for row in rows])
else:
row: Row = result
row_fields = row.__fields__
row_object = namedtuple("Row", row_fields) # type: ignore[misc]
return cast(tuple, row_object(*row))

def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
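
As a usage sketch (the connection id and query are made up; `fetch_all_handler` is the stock handler from common.sql), the flag changes the row type the hook hands back:

```python
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(databricks_conn_id="databricks_default", return_tuple=True)
rows = hook.run("SELECT id, name FROM some_table", handler=fetch_all_handler)
# With return_tuple=True each row is a dynamically built namedtuple: field
# access (row.id) still works, but the row is a plain tuple underneath and
# therefore JSON/XCom-serializable, unlike a raw databricks.sql.Row.
```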
1 change: 1 addition & 0 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -113,6 +113,7 @@ def get_db_hook(self) -> DatabricksSqlHook:
"catalog": self.catalog,
"schema": self.schema,
"caller": "DatabricksSqlOperator",
"return_tuple": True,
**self.client_parameters,
**self.hook_params,
}
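
The effect of that one-line change, as an illustrative sketch (task id, connection id, and SQL are made up):

```python
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

# The operator now always builds its hook with return_tuple=True, so the rows
# it pushes to XCom are namedtuples rather than databricks.sql.Row objects.
op = DatabricksSqlOperator(
    task_id="run_select",
    databricks_conn_id="databricks_default",
    sql="SELECT 1",
)
```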
2 changes: 1 addition & 1 deletion airflow/providers/databricks/provider.yaml
@@ -59,7 +59,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.8.1
- apache-airflow-providers-common-sql>=1.9.1
- requests>=2.27,<3
# The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
# it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190
6 changes: 3 additions & 3 deletions airflow/providers/exasol/hooks/exasol.py
@@ -183,7 +183,7 @@ def run(
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
...

def run(
@@ -194,7 +194,7 @@ def run(
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> T | list[T] | None:
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""Run a command or a list of commands.

Pass a list of SQL statements to the SQL parameter to get them to
@@ -232,7 +232,7 @@ def run(
with closing(conn.execute(sql_statement, parameters)) as exa_statement:
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
result = handler(exa_statement)
result = self._make_common_data_structure(handler(exa_statement))
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_columns = self.get_description(exa_statement)
2 changes: 1 addition & 1 deletion airflow/providers/exasol/provider.yaml
@@ -52,7 +52,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.3.1
- apache-airflow-providers-common-sql>=1.9.1
- pyexasol>=0.5.1
- pandas>=0.17.1

32 changes: 15 additions & 17 deletions airflow/providers/odbc/hooks/odbc.py
@@ -17,10 +17,10 @@
"""This module contains ODBC hook."""
from __future__ import annotations

from typing import Any, NamedTuple
from typing import Any, List, NamedTuple, Sequence, cast
from urllib.parse import quote_plus

import pyodbc
from pyodbc import Connection, Row, connect

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.helpers import merge_dicts
@@ -193,9 +193,9 @@ def connect_kwargs(self) -> dict:

return merged_connect_kwargs

def get_conn(self) -> pyodbc.Connection:
def get_conn(self) -> Connection:
"""Returns a pyodbc connection object."""
conn = pyodbc.connect(self.odbc_connection_string, **self.connect_kwargs)
conn = connect(self.odbc_connection_string, **self.connect_kwargs)
return conn

def get_uri(self) -> str:
@@ -212,17 +212,15 @@ def get_sqlalchemy_connection(
cnx = engine.connect(**(connect_kwargs or {}))
return cnx

@staticmethod
def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) -> list[NamedTuple] | None:
"""Transform the pyodbc.Row objects returned from an SQL command into JSON-serializable NamedTuple."""
def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
"""Transform the pyodbc.Row objects returned from an SQL command into typed NamedTuples."""
# Below ignored lines respect NamedTuple docstring, but mypy does not support dynamically
# instantiated Namedtuple, and will never do: https://github.com/python/mypy/issues/848
columns: list[tuple[str, type]] | None = None
if isinstance(result, list):
columns = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", columns) # type: ignore[misc]
return [row_object(*row) for row in result]
elif isinstance(result, pyodbc.Row):
columns = [col[:2] for col in result.cursor_description]
return NamedTuple("Row", columns)(*result) # type: ignore[misc, operator]
return result
# instantiated typed Namedtuple, and never will: https://github.com/python/mypy/issues/848
field_names: list[tuple[str, type]] | None = None
if isinstance(result, Sequence):
field_names = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", field_names) # type: ignore[misc]
return cast(List[tuple], [row_object(*row) for row in result])
else:
field_names = [col[:2] for col in result.cursor_description]
return cast(tuple, NamedTuple("Row", field_names)(*result)) # type: ignore[misc, operator]
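
The dynamic typed-NamedTuple pattern in isolation, with hypothetical column metadata shaped like pyodbc's `cursor_description` (a `(name, type_code, ...)` tuple per column):

```python
from typing import NamedTuple

# Two fake columns, mimicking pyodbc's 7-item column description tuples.
cursor_description = [
    ("id", int, None, 10, 10, 0, False),
    ("name", str, None, 255, 255, 0, True),
]
fields = [col[:2] for col in cursor_description]  # -> [("id", int), ("name", str)]
Row = NamedTuple("Row", fields)  # type: ignore[misc]  # dynamic NamedTuple, as in the hook
rows = [Row(*values) for values in [(1, "alice"), (2, "bob")]]
assert rows[0].name == "alice" and isinstance(rows[0], tuple)
```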
2 changes: 1 addition & 1 deletion airflow/providers/odbc/provider.yaml
@@ -45,7 +45,7 @@ versions:

dependencies:
- apache-airflow>=2.6.0
- apache-airflow-providers-common-sql>=1.8.1
- apache-airflow-providers-common-sql>=1.9.1
- pyodbc

integrations: