@@ -162,7 +162,7 @@ def execute(self, context: Context):
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)

results = self.destination_hook.get_records(self.sql)
results = self.source_hook.get_records(self.sql)

self.log.info("Inserting rows into %s", self.destination_conn_id)
self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
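
This one-line change is the actual bug fix: GenericTransfer.execute was calling get_records on the destination hook, so the extraction query ran against the very database the rows were about to be loaded into. Reading through the source hook restores the intended extract-then-load flow, sketched minimally below; the sketch assumes only the get_records/insert_rows pair that Airflow's DbApiHook exposes, and the transfer wrapper itself is hypothetical, for illustration only.

def transfer(source_hook, destination_hook, sql, destination_table, **insert_args):
    # Read from the SOURCE connection -- this is what the fix restores.
    results = source_hook.get_records(sql)
    # Write to the DESTINATION connection, unchanged by the fix.
    destination_hook.insert_rows(table=destination_table, rows=results, **insert_args)

The remaining hunks update the provider's unit tests for GenericTransfer.
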
@@ -24,6 +24,7 @@
 from unittest.mock import MagicMock

 import pytest
+from more_itertools import flatten

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models.connection import Connection
@@ -34,7 +35,7 @@
 from airflow.utils import timezone

 from tests_common.test_utils.compat import GenericTransfer
-from tests_common.test_utils.operators.run_deferrable import execute_operator
+from tests_common.test_utils.operators.run_deferrable import execute_operator, mock_context
 from tests_common.test_utils.providers import get_provider_min_airflow_version

 pytestmark = pytest.mark.db_test
@@ -43,6 +44,12 @@
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
+INSERT_ARGS = {
+    "commit_every": 1000,  # Number of rows inserted in each batch
+    "executemany": True,  # Enable batch inserts
+    "fast_executemany": True,  # Boost performance for MSSQL inserts
+    "replace": True,  # Used for upserts/merges if needed
+}
 counter = 0

@@ -175,6 +182,44 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):


 class TestGenericTransfer:
+    mocked_source_hook = MagicMock(conn_name_attr="my_source_conn_id", spec=DbApiHook)
+    mocked_destination_hook = MagicMock(conn_name_attr="my_destination_conn_id", spec=DbApiHook)
+    mocked_hooks = {
+        "my_source_conn_id": mocked_source_hook,
+        "my_destination_conn_id": mocked_destination_hook,
+    }
+
+    @classmethod
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None):
+        return cls.mocked_hooks[conn_id]
+
+    @classmethod
+    def get_connection(cls, conn_id: str):
+        mocked_hook = cls.get_hook(conn_id=conn_id)
+        mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
+        mocked_conn.get_hook.return_value = mocked_hook
+        return mocked_conn
+
+    def setup_method(self):
+        # Reset mock states before each test
+        self.mocked_source_hook.reset_mock()
+        self.mocked_destination_hook.reset_mock()
+
+        # Set up the side effect for paginated read
+        records = [
+            [[1, 2], [11, 12], [3, 4], [13, 14]],
+            [[3, 4], [13, 14]],
+        ]
+
+        def get_records_side_effect(sql: str):
+            if records:
+                if "LIMIT" not in sql:
+                    return list(flatten(records))
+                return records.pop(0)
+            return []
+
+        self.mocked_source_hook.get_records.side_effect = get_records_side_effect
+
     def test_templated_fields(self):
         dag = DAG(
             "test_dag",
@@ -209,53 +254,45 @@ def test_templated_fields(self):
         assert operator.preoperator == "my_preoperator"
         assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}

+    def test_non_paginated_read(self):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
+                operator = GenericTransfer(
+                    task_id="transfer_table",
+                    source_conn_id="my_source_conn_id",
+                    destination_conn_id="my_destination_conn_id",
+                    sql="SELECT * FROM HR.EMPLOYEES",
+                    destination_table="NEW_HR.EMPLOYEES",
+                    insert_args=INSERT_ARGS,
+                    execution_timeout=timedelta(hours=1),
+                )
+
+                operator.execute(context=mock_context(task=operator))
+
+        assert self.mocked_source_hook.get_records.call_count == 1
+        assert self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES"
+        assert self.mocked_destination_hook.insert_rows.call_count == 1
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_paginated_read(self):
         """
         This unit test is based on the example described in the medium article:
         https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
         """

-        def create_get_records_side_effect():
-            records = [
-                [[1, 2], [11, 12], [3, 4], [13, 14]],
-                [[3, 4], [13, 14]],
-            ]
-
-            def side_effect(sql: str):
-                if records:
-                    return records.pop(0)
-                return []
-
-            return side_effect
-
-        get_records_side_effect = create_get_records_side_effect()
-
-        def get_hook(conn_id: str, hook_params: dict | None = None):
-            mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
-            mocked_hook.get_records.side_effect = get_records_side_effect
-            return mocked_hook
-
-        def get_connection(conn_id: str):
-            mocked_hook = get_hook(conn_id=conn_id)
-            mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
-            mocked_conn.get_hook.return_value = mocked_hook
-            return mocked_conn
-
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
-            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
                 operator = GenericTransfer(
                     task_id="transfer_table",
                     source_conn_id="my_source_conn_id",
                     destination_conn_id="my_destination_conn_id",
                     sql="SELECT * FROM HR.EMPLOYEES",
                     destination_table="NEW_HR.EMPLOYEES",
                     page_size=1000,  # Fetch data in chunks of 1000 rows for pagination
-                    insert_args={
-                        "commit_every": 1000,  # Number of rows inserted in each batch
-                        "executemany": True,  # Enable batch inserts
-                        "fast_executemany": True,  # Boost performance for MSSQL inserts
-                        "replace": True,  # Used for upserts/merges if needed
-                    },
+                    insert_args=INSERT_ARGS,
                     execution_timeout=timedelta(hours=1),
                 )

@@ -267,6 +304,21 @@ def get_connection(conn_id: str):
         assert events[1].payload["results"] == [[3, 4], [13, 14]]
         assert not events[2].payload["results"]

+        assert self.mocked_source_hook.get_records.call_count == 3
+        assert (
+            self.mocked_source_hook.get_records.call_args_list[0].args[0]
+            == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"
+        )
+        assert self.mocked_destination_hook.insert_rows.call_count == 2
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+        assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
         """
         Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher
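
The paginated assertions in the final hunk pin down the contract: three get_records calls (the first asserted verbatim as "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"), one insert_rows call per non-empty page, and fetching that stops once a page comes back empty. A sketch of the paging loop those assertions imply follows; the plain LIMIT/OFFSET suffix is an assumption lifted from the asserted SQL, not necessarily how the operator renders pagination for every database dialect.

def fetch_pages(get_records, sql: str, page_size: int):
    """Yield pages of rows until the source returns an empty page."""
    offset = 0
    while True:
        page = get_records(f"{sql} LIMIT {page_size} OFFSET {offset}")
        if not page:
            return  # empty page: every row has been fetched
        yield page
        offset += page_size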