Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def execute(self, context: Context):
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)

results = self.destination_hook.get_records(self.sql)
results = self.source_hook.get_records(self.sql)

self.log.info("Inserting rows into %s", self.destination_conn_id)
self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,53 @@ def test_templated_fields(self):
assert operator.preoperator == "my_preoperator"
assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}

def test_not_paginated_transfer(self):

mocked_source_hook = mock.MagicMock(conn_name_attr='my_source_conn_id', spec=DbApiHook)
mocked_destination_hook = mock.MagicMock(conn_name_attr='my_destination_conn_id', spec=DbApiHook)

def get_hook(conn_id: str, hook_params: dict | None = None):
return {
'my_source_conn_id': mocked_source_hook,
'my_destination_conn_id': mocked_destination_hook
}[conn_id]

def get_connection(conn_id: str):
mocked_hook = get_hook(conn_id=conn_id)
mocked_conn = mock.MagicMock(conn_id=conn_id, spec=Connection)
mocked_conn.get_hook.return_value = mocked_hook
return mocked_conn

sql_statement = "SELECT * FROM generic_transfer"
preoperator_statements = [
"DROP TABLE IF EXISTS test_generic_transfer",
"CREATE TABLE test_generic_transfer(LIKE generic_transfer INCLUDING INDEXES)"
]
destination_table = "test_generic_transfer"
operator = GenericTransfer(
task_id="transfer_table",
source_conn_id="my_source_conn_id",
destination_conn_id="my_destination_conn_id",
sql=sql_statement,
preoperator=preoperator_statements,
destination_table=destination_table
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
execute_operator(operator)

assert mocked_destination_hook.run.call_count == 1
assert mocked_destination_hook.run.call_args_list[0].args[0] == preoperator_statements
assert not mocked_source_hook.run.called

assert mocked_source_hook.get_records.call_count == 1
assert mocked_source_hook.get_records.call_args_list[0].args[0] == sql_statement
assert not mocked_destination_hook.get_records.called

assert mocked_destination_hook.insert_rows.call_count == 1
assert mocked_destination_hook.insert_rows.call_args_list[0].kwargs['table'] == destination_table
assert not mocked_source_hook.insert_rows.called

def test_paginated_read(self):
"""
This unit test is based on the example described in the medium article:
Expand All @@ -228,12 +275,15 @@ def side_effect(sql: str):

return side_effect

get_records_side_effect = create_get_records_side_effect()
mocked_source_hook = mock.MagicMock(conn_name_attr='my_source_conn_id', spec=DbApiHook)
mocked_source_hook.get_records.side_effect = create_get_records_side_effect()
mocked_destination_hook = mock.MagicMock(conn_name_attr='my_destination_conn_id', spec=DbApiHook)

def get_hook(conn_id: str, hook_params: dict | None = None):
mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
mocked_hook.get_records.side_effect = get_records_side_effect
return mocked_hook
return {
'my_source_conn_id': mocked_source_hook,
'my_destination_conn_id': mocked_destination_hook
}[conn_id]

def get_connection(conn_id: str):
mocked_hook = get_hook(conn_id=conn_id)
Expand Down Expand Up @@ -266,6 +316,8 @@ def get_connection(conn_id: str):
assert events[0].payload["results"] == [[1, 2], [11, 12], [3, 4], [13, 14]]
assert events[1].payload["results"] == [[3, 4], [13, 14]]
assert not events[2].payload["results"]
assert mocked_source_hook.get_records.called
assert mocked_destination_hook.insert_rows.called

def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
"""
Expand Down
Loading