Unify DbApiHook.run() method with the methods which override it #23971

Merged Jul 22, 2022 · 7 commits · Changes from all commits
2 changes: 1 addition & 1 deletion airflow/operators/sql.py
@@ -496,7 +496,7 @@ def __init__(
         follow_task_ids_if_false: List[str],
         conn_id: str = "default_conn_id",
         database: Optional[str] = None,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_sql.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -55,7 +55,7 @@ def __init__(
         *,
         sql: Union[str, Iterable[str]],
         redshift_conn_id: str = 'redshift_default',
-        parameters: Optional[dict] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         autocommit: bool = True,
         **kwargs,
     ) -> None:
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -93,7 +93,7 @@ def __init__(
         unload_options: Optional[List] = None,
         autocommit: bool = False,
         include_header: bool = False,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         table_as_file_name: bool = True,  # Set to True by default for not breaking current workflows
         **kwargs,
     ) -> None:
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -16,7 +16,7 @@
 # under the License.

 import warnings
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -140,7 +140,7 @@ def execute(self, context: 'Context') -> None:

         copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

-        sql: Union[list, str]
+        sql: Union[str, Iterable[str]]

         if self.method == 'REPLACE':
             sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
8 changes: 2 additions & 6 deletions airflow/providers/apache/drill/operators/drill.py
@@ -17,8 +17,6 @@
 # under the License.
 from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

-import sqlparse
-
 from airflow.models import BaseOperator
 from airflow.providers.apache.drill.hooks.drill import DrillHook

@@ -52,7 +50,7 @@ def __init__(
         *,
         sql: str,
         drill_conn_id: str = 'drill_default',
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -64,6 +62,4 @@ def __init__(
     def execute(self, context: 'Context'):
         self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
         self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
-        sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True))
-        no_term_sql = [s[:-1] for s in sql if s[-1] == ';']
-        self.hook.run(no_term_sql, parameters=self.parameters)
+        self.hook.run(self.sql, parameters=self.parameters, split_statements=True)
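The statement-splitting that this operator previously did by hand now lives in the shared hook, so the single run() call above is all that is needed. A minimal sketch of the new call pattern, assuming a reachable Drill connection named drill_default; the SQL string is illustrative:

from airflow.providers.apache.drill.hooks.drill import DrillHook

# Illustrative multi-statement SQL; comments and trailing semicolons are
# handled by the hook's sqlparse-based split_sql_string().
sql = """
CREATE TABLE dfs.tmp.example AS SELECT 1 AS a;  -- hypothetical table
SELECT a FROM dfs.tmp.example;
"""

hook = DrillHook(drill_conn_id='drill_default')
# split_statements=True makes the hook split the string and execute each
# statement separately; with no handler, run() returns None.
hook.run(sql, parameters=None, split_statements=True)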
1 change: 0 additions & 1 deletion airflow/providers/apache/drill/provider.yaml
@@ -34,7 +34,6 @@ dependencies:
   - apache-airflow>=2.2.0
   - apache-airflow-providers-common-sql
   - sqlalchemy-drill>=1.1.0
-  - sqlparse>=0.4.1

 integrations:
   - integration-name: Apache Drill
6 changes: 3 additions & 3 deletions airflow/providers/apache/pinot/hooks/pinot.py
@@ -18,7 +18,7 @@

 import os
 import subprocess
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Mapping, Optional, Union

 from pinotdb import connect

@@ -275,7 +275,7 @@ def get_uri(self) -> str:
         endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
         return f'{conn_type}://{host}/{endpoint}'

-    def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+    def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
         """
         Executes the sql and returns a set of records.

@@ -287,7 +287,7 @@
             cur.execute(sql)
             return cur.fetchall()

-    def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+    def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
         """
         Executes the sql and returns the first resulting row.

68 changes: 54 additions & 14 deletions airflow/providers/common/sql/hooks/sql.py
@@ -17,8 +17,9 @@
 import warnings
 from contextlib import closing
 from datetime import datetime
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union

+import sqlparse
 from sqlalchemy import create_engine
 from typing_extensions import Protocol

@@ -27,6 +28,17 @@
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.module_loading import import_string

+if TYPE_CHECKING:
+    from sqlalchemy.engine import CursorResult
+
+
+def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
+    """Handler for DbApiHook.run() to return results"""
+    if cursor.returns_rows:
+        return cursor.fetchall()
+    else:
+        return None
+
+
 def _backported_get_hook(connection, *, hook_params=None):
     """Return hook based on conn_type
@@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None):
             cur.execute(sql)
             return cur.fetchone()

-    def run(self, sql, autocommit=False, parameters=None, handler=None):
+    @staticmethod
+    def strip_sql_string(sql: str) -> str:
+        return sql.strip().rstrip(';')
+
+    @staticmethod
+    def split_sql_string(sql: str) -> List[str]:
+        """
+        Splits string into multiple SQL expressions
+
+        :param sql: SQL string potentially consisting of multiple expressions
+        :return: list of individual expressions
+        """
+        splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+        statements = [s.rstrip(';') for s in splits if s.endswith(';')]
+        return statements
+
+    def run(
+        self,
+        sql: Union[str, Iterable[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Optional[Callable] = None,
+        split_statements: bool = False,
+        return_last: bool = True,
+    ) -> Optional[Union[Any, List[Any]]]:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
             before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
-        :return: query results if handler was provided.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
+        :param return_last: Whether to return result for only last statement or for all after split
+        :return: return only result of the ALL SQL expressions if handler was provided.
         """
-        scalar = isinstance(sql, str)
-        if scalar:
-            sql = [sql]
+        scalar_return_last = isinstance(sql, str) and return_last
+        if isinstance(sql, str):
+            if split_statements:
+                sql = self.split_sql_string(sql)
+            else:
+                sql = [self.strip_sql_string(sql)]

         if sql:
-            self.log.debug("Executing %d statements", len(sql))
+            self.log.debug("Executing following statements against DB: %s", list(sql))
         else:
             raise ValueError("List of SQL statements is empty")

@@ -232,22 +273,21 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
                 results = []
                 for sql_statement in sql:
                     self._run_command(cur, sql_statement, parameters)
+
                     if handler is not None:
                         result = handler(cur)
                         results.append(result)

-            # If autocommit was set to False for db that supports autocommit,
-            # or if db does not supports autocommit, we do a manual commit.
+            # If autocommit was set to False or db does not support autocommit, we do a manual commit.
             if not self.get_autocommit(conn):
                 conn.commit()

         if handler is None:
             return None
-
-        if scalar:
-            return results[0]
-
-        return results
+        elif scalar_return_last:
+            return results[-1]
+        else:
+            return results
Review discussion on this return block:

Member: I think this is backwards-incompatible. Currently, when run() gets a single SQL string it returns a single result; if it gets an array of strings it returns an array of results. In order to keep backwards compatibility we should retain this behaviour. I think the reasonable approach will be to check whether, even after the split, we have one or multiple queries: if 1 -> return results[0], else results.

Member: Hmm, I see that there are different behaviours in different operators. I think for now maybe we need another parameter (scalar_single_result: bool)? This way we could keep full backwards compatibility in all the providers, I think.

@kazanzhy (Contributor, Author) on Jul 17, 2022: I think in the very first iteration it was return results[-1], which is good for a single statement and for a list of them. In one of our previous PRs we got the idea to return all statements' results: #23623 (comment).

In this case we will have the return type Optional[List[HandlerResult]]. Before that, it was Optional[HandlerResult]. We could add a parameter return_last=True, and then we would have Optional[Union[HandlerResult, List[HandlerResult]]]. For example:

  • if no handler then -> None; otherwise:
  • if one statement and return_last=True then -> HandlerResult
  • if split_statements=True and return_last=True then -> HandlerResult
  • if split_statements=True and return_last=False then -> List[HandlerResult]

Also, for one statement we could return a list with a single value: if one statement and return_last=False then -> List[HandlerResult].

In code, I assume it'll be:

if handler is None:
    return None
elif return_last:
    return results[-1]
else:
    return results

Contributor: The return_last flag seems reasonable. I would like to point to my note in #23623 (comment), but that ship has sailed.

     def _run_command(self, cur, sql_statement, parameters):
         """Runs a statement using an already open cursor."""
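Taken together, split_statements and return_last let a caller control both how a SQL string is split and the shape of the returned results. A hedged sketch of the resulting call patterns, using PostgresHook as an illustrative DbApiHook subclass and a plain lambda handler (fetch_all_handler assumes a cursor exposing returns_rows, per its type hint); the connection ID is hypothetical:

from airflow.providers.postgres.hooks.postgres import PostgresHook  # illustrative subclass

hook = PostgresHook(postgres_conn_id='postgres_default')  # hypothetical connection ID

# Single string with the default return_last=True -> the single handler result.
rows = hook.run("SELECT 1", handler=lambda cur: cur.fetchall())

# Multi-statement string: split_statements=True splits it via sqlparse,
# and return_last=False returns one handler result per statement.
all_results = hook.run(
    "SELECT 1; SELECT 2;",
    handler=lambda cur: cur.fetchall(),
    split_statements=True,
    return_last=False,
)
# all_results is a list like [[(1,)], [(2,)]], one entry per statement.

# Without a handler, run() always returns None.
assert hook.run("SELECT 1") is None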
3 changes: 2 additions & 1 deletion airflow/providers/common/sql/provider.yaml
@@ -24,7 +24,8 @@ description: |
 versions:
   - 1.0.0

-dependencies: []
+dependencies:
+  - sqlparse>=0.4.2

 additional-extras:
   - name: pandas
76 changes: 37 additions & 39 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -15,10 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.

-import re
 from contextlib import closing
 from copy import copy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union

 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
@@ -139,19 +138,15 @@ def get_conn(self) -> Connection:
         )
         return self._sql_conn

-    @staticmethod
-    def maybe_split_sql_string(sql: str) -> List[str]:
-        """
-        Splits strings consisting of multiple SQL expressions into an
-        TODO: do we need something more sophisticated?
-
-        :param sql: SQL string potentially consisting of multiple expressions
-        :return: list of individual expressions
-        """
-        splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
-        return splits
-
-    def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
+    def run(
+        self,
+        sql: Union[str, Iterable[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Optional[Callable] = None,
+        split_statements: bool = True,
+        return_last: bool = True,
+    ) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -163,41 +158,44 @@
             before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
-        :return: query results.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
+        :param return_last: Whether to return result for only last statement or for all after split
+        :return: return only result of the LAST SQL expression if handler was provided.
         """
+        scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
-            sql = self.maybe_split_sql_string(sql)
+            if split_statements:
+                sql = self.split_sql_string(sql)
+            else:
+                sql = [self.strip_sql_string(sql)]

-        self.log.debug("Executing %d statements", len(sql))
+        if sql:
+            self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
+        else:
+            raise ValueError("List of SQL statements is empty")

-        conn = None
+        results = []
         for sql_statement in sql:
             # when using AAD tokens, it could expire if previous query run longer than token lifetime
-            conn = self.get_conn()
-            with closing(conn.cursor()) as cur:
-                self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
-                if parameters:
-                    cur.execute(sql_statement, parameters)
-                else:
-                    cur.execute(sql_statement)
-                schema = cur.description
-                results = []
-                if handler is not None:
-                    cur = handler(cur)
-                for row in cur:
-                    self.log.debug("Statement results: %s", row)
-                    results.append(row)
-
-                self.log.info("Rows affected: %s", cur.rowcount)
-        if conn:
-            conn.close()
+            with closing(self.get_conn()) as conn:
+                self.set_autocommit(conn, autocommit)
+
+                with closing(conn.cursor()) as cur:
+                    self._run_command(cur, sql_statement, parameters)
+
+                    if handler is not None:
+                        result = handler(cur)
+                        schema = cur.description
+                        results.append((schema, result))

         self._sql_conn = None

-        # Return only result of the last SQL expression
-        return schema, results
+        if handler is None:
+            return None
+        elif scalar_return_last:
+            return results[-1]
+        else:
+            return results

     def test_connection(self):
         """Test the Databricks SQL connection by running a simple query."""
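Note the hook-specific return shape: each executed statement contributes a (schema, results) tuple, so return_last=True yields one tuple and return_last=False a list of them. A hedged sketch of calling the reworked hook directly; the connection ID and SQL are illustrative, and fetch_all_handler (the helper added above) is assumed to work against the cursor the hook passes it:

from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(databricks_conn_id='databricks_default')  # hypothetical connection

# Default return_last=True: one (schema, rows) tuple for the last statement.
schema, rows = hook.run("SELECT 1", handler=fetch_all_handler)

# return_last=False: one (schema, rows) tuple per executed statement.
per_statement = hook.run(
    "SELECT 1; SELECT 2;",
    handler=fetch_all_handler,
    return_last=False,  # split_statements already defaults to True in this hook
)
for schema, rows in per_statement:
    print(schema, rows)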
12 changes: 7 additions & 5 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -20,12 +20,13 @@

 import csv
 import json
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast

 from databricks.sql.utils import ParamEscaper

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
 from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

 if TYPE_CHECKING:
@@ -71,11 +72,11 @@ class DatabricksSqlOperator(BaseOperator):
     def __init__(
         self,
         *,
-        sql: Union[str, List[str]],
+        sql: Union[str, Iterable[str]],
         databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         session_configuration=None,
         http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
@@ -147,10 +148,11 @@ def _format_output(self, schema, results):
         else:
             raise AirflowException(f"Unsupported output format: '{self._output_format}'")

-    def execute(self, context: 'Context') -> Any:
+    def execute(self, context: 'Context'):
         self.log.info('Executing: %s', self.sql)
         hook = self._get_hook()
-        schema, results = hook.run(self.sql, parameters=self.parameters)
+        response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler)
+        schema, results = cast(List[Tuple[Any, Any]], response)[0]
         # self.log.info('Schema: %s', schema)
         # self.log.info('Results: %s', results)
         self._format_output(schema, results)
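At the DAG level nothing changes for users of DatabricksSqlOperator; the handler wiring above is internal. A minimal usage sketch, in which the DAG id, connection, endpoint name, and table are all illustrative:

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id='example_databricks_sql',  # hypothetical DAG
    start_date=datetime(2022, 7, 1),
    schedule_interval=None,
) as dag:
    # The operator runs the SQL through DatabricksSqlHook and formats the
    # fetched rows itself; no handler needs to be passed from the DAG.
    select_data = DatabricksSqlOperator(
        task_id='select_data',
        databricks_conn_id='databricks_default',  # hypothetical connection
        sql_endpoint_name='my-endpoint',  # hypothetical SQL endpoint
        sql="SELECT * FROM my_catalog.my_schema.my_table LIMIT 10",
    )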