Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make conn id parameters templated in GenericTransfer and also allow passing hook parameters like in BaseSQLOperator #42891

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
63e450f
refactored: Added hook_params to get_hook BaseHook method, templated …
davidblain-infrabel Oct 10, 2024
d45102d
refactored: Added unit test to check when to remove obsolete code in …
davidblain-infrabel Oct 10, 2024
a44217d
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 10, 2024
80b1486
refactored: Reformatted TestGenericTransfer
davidblain-infrabel Oct 10, 2024
1ee071f
refactored: Reformatted GenericTransfer
davidblain-infrabel Oct 10, 2024
626b54b
refactored: Don't need to test min Airflow version for GenericTransfe…
davidblain-infrabel Oct 10, 2024
dae835c
refactored: Don't need to redefine get_hook method with hook_params i…
davidblain-infrabel Oct 10, 2024
d868cc1
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 10, 2024
ab29b0a
refactored: Reorganized imports test generic transfer
davidblain-infrabel Oct 10, 2024
f36cd3d
refactored: Fixed import of test_utils in test_dag_run
davidblain-infrabel Oct 10, 2024
c7b7407
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 11, 2024
47b9180
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 11, 2024
a59065d
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 14, 2024
598acfd
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 15, 2024
756ae98
refactor: Moved GenericTransfer to standard operators provider
davidblain-infrabel Oct 15, 2024
81ea259
refactor: Removed duplicate imports of AIRFLOW version constants
davidblain-infrabel Oct 15, 2024
611387e
refactor: Standard provider should be dependant of apache-airflow-pro…
davidblain-infrabel Oct 15, 2024
9a0de7f
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 15, 2024
71cd36f
refactor: Updated provider dependencies
dabla Oct 15, 2024
d0ea2be
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 15, 2024
e3ba87a
refactor: Reformatted TestGenericTransfer
davidblain-infrabel Oct 16, 2024
1a23c7f
refactor: Added generic transfer operator in standard operator provid…
davidblain-infrabel Oct 16, 2024
7626d54
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 16, 2024
ac5f210
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 16, 2024
3450f4f
refactor: Fixed override of get_hook method in GenericTransfer
davidblain-infrabel Oct 16, 2024
ca16162
refactor: Fixed import of get_provider_min_airflow_version in test sq…
davidblain-infrabel Oct 16, 2024
e0ef9b3
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 16, 2024
82f8f52
refactor: Removed whiteline in TestBaseSQLOperator
davidblain-infrabel Oct 16, 2024
1ac62cc
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 16, 2024
3f96cd0
refactor: Reorganized imports TestGenericTransfer
davidblain-infrabel Oct 16, 2024
7114d97
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 16, 2024
b1a061f
refactor: Reorganized imports TestBaseSQLOperator
davidblain-infrabel Oct 17, 2024
c0df82e
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 17, 2024
332d68d
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 18, 2024
6654b2c
Merge branch 'main' into feature/templated_conn_id_generic_transfer
dabla Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions airflow/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ def get_connection(cls, conn_id: str) -> Connection:
return conn

@classmethod
def get_hook(cls, conn_id: str) -> BaseHook:
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook:
"""
Return default hook for this connection id.

:param conn_id: connection id
:param hook_params: hook parameters
:return: default hook for this connection
"""
connection = cls.get_connection(conn_id)
return connection.get_hook()
return connection.get_hook(hook_params=hook_params)

def get_conn(self) -> Any:
"""Return connection for the hook."""
Expand Down
22 changes: 17 additions & 5 deletions airflow/operators/generic_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,20 @@ class GenericTransfer(BaseOperator):

:param sql: SQL query to execute against the source database. (templated)
:param destination_table: target table. (templated)
:param source_conn_id: source connection
:param destination_conn_id: destination connection
:param source_conn_id: source connection. (templated)
:param destination_conn_id: destination connection. (templated)
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
"""

template_fields: Sequence[str] = ("sql", "destination_table", "preoperator")
template_fields: Sequence[str] = (
"source_conn_id",
"destination_conn_id",
"sql",
"destination_table",
"preoperator",
)
template_ext: Sequence[str] = (
".sql",
".hql",
Expand All @@ -59,7 +65,9 @@ def __init__(
sql: str,
destination_table: str,
source_conn_id: str,
source_hook_params: dict | None = None,
destination_conn_id: str,
destination_hook_params: dict | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
**kwargs,
Expand All @@ -68,13 +76,17 @@ def __init__(
self.sql = sql
self.destination_table = destination_table
self.source_conn_id = source_conn_id
self.source_hook_params = source_hook_params
self.destination_conn_id = destination_conn_id
self.destination_hook_params = destination_hook_params
self.preoperator = preoperator
self.insert_args = insert_args or {}

def execute(self, context: Context):
source_hook = BaseHook.get_hook(self.source_conn_id)
destination_hook = BaseHook.get_hook(self.destination_conn_id)
source_hook = BaseHook.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params)
dabla marked this conversation as resolved.
Show resolved Hide resolved
destination_hook = BaseHook.get_hook(
conn_id=self.destination_conn_id, hook_params=self.destination_hook_params
)

self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)
Expand Down
16 changes: 14 additions & 2 deletions providers/src/airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,25 @@ def __init__(
self.hook_params = hook_params or {}
self.retry_on_failure = retry_on_failure

@classmethod
dabla marked this conversation as resolved.
Show resolved Hide resolved
# TODO: can be removed once Airflow min version for this provider is 3.0.0 or higher
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook:
    """
    Look up the connection for ``conn_id`` and return the hook it builds.

    :param conn_id: id of the connection to resolve through ``BaseHook.get_connection``
    :param hook_params: optional parameters forwarded as ``hook_params`` to the
        connection's ``get_hook`` call
    :return: the hook returned by the resolved connection
    """
    return BaseHook.get_connection(conn_id).get_hook(hook_params=hook_params)

@cached_property
def _hook(self):
"""Get DB Hook based on connection type."""
conn_id = getattr(self, self.conn_id_field)
self.log.debug("Get connection for %s", conn_id)
conn = BaseHook.get_connection(conn_id)
hook = conn.get_hook(hook_params=self.hook_params)
hook = self.get_hook(conn_id=conn_id, hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
raise AirflowException(
f"You are trying to use `common-sql` with {hook.__class__.__name__},"
Expand Down
21 changes: 20 additions & 1 deletion providers/tests/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from __future__ import annotations

import datetime
import inspect
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow import DAG
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
Expand All @@ -45,6 +46,7 @@
from airflow.utils.state import State

from dev.tests_common.test_utils.compat import AIRFLOW_V_2_8_PLUS, AIRFLOW_V_3_0_PLUS
from dev.tests_common.test_utils.providers import get_provider_min_airflow_version

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
Expand Down Expand Up @@ -91,6 +93,23 @@ def test_templated_fields(self):
assert operator.database == "my_database"
assert operator.hook_params == {"key": "value"}

def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_get_hook_method(self):
    """
    Remind maintainers to delete the backward-compatibility ``get_hook`` of BaseSQLOperator.

    This test starts failing as soon as the minimum Airflow version required by the
    common-sql provider is 3.0.0 or higher. When that happens, remove the obsolete
    ``get_hook`` method from BaseSQLOperator and remove this test as well. The method
    only exists as fallback code for backward compatibility with Airflow 2.8.x, which
    isn't needed anymore once this provider depends on Airflow 3.0.0 or higher.
    """
    provider_min_version = get_provider_min_airflow_version("apache-airflow-providers-common-sql")

    # The check is on the provider's *minimum supported* Airflow version (major part),
    # not on the Airflow version the tests are currently running against.
    fallback_is_obsolete = provider_min_version[0] >= 3
    if not fallback_is_obsolete:
        return

    # Include the source of the obsolete method in the failure so it is easy to locate.
    obsolete_source = inspect.getsource(BaseSQLOperator.get_hook)
    raise AirflowProviderDeprecationWarning(
        f"Check TODO's to remove obsolete get_hook method in BaseSQLOperator:\n\r\n\r\t\t\t{obsolete_source}"
    )


class TestSQLExecuteQueryOperator:
def _construct_operator(self, sql, **kwargs):
Expand Down
34 changes: 34 additions & 0 deletions tests/operators/test_generic_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from contextlib import closing
from datetime import datetime
from unittest import mock

import pytest
Expand Down Expand Up @@ -151,3 +152,36 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
assert mock_insert.called
_, kwargs = mock_insert.call_args
assert "replace" in kwargs


class TestGenericTransfer:
    """Checks that the templated fields of ``GenericTransfer`` are rendered."""

    def test_templated_fields(self):
        """Every templated field must take the value supplied in the render context."""
        # Field name -> value the template is expected to render to.
        render_context = {
            "sql": "my_sql",
            "destination_table": "my_destination_table",
            "source_conn_id": "my_source_conn_id",
            "destination_conn_id": "my_destination_conn_id",
            "preoperator": "my_preoperator",
        }
        # Each field is given a template that references its own name, e.g. "{{ sql }}".
        templated_kwargs = {name: "{{ " + name + " }}" for name in render_context}

        dag = DAG(
            "test_dag",
            schedule=None,
            start_date=datetime(2024, 10, 10),
            render_template_as_native_obj=True,
        )
        operator = GenericTransfer(task_id="test_task", dag=dag, **templated_kwargs)

        operator.render_template_fields(render_context)

        for name, expected in render_context.items():
            assert getattr(operator, name) == expected