Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,21 @@ def get_oauth_token(
conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
)

def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
def get_request_url_header_params(
self, query_id: str, url_suffix: str | None = None
) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Build the request header Url with account name identifier and query id from the connection params.

:param query_id: statement handles query ids for the individual statements.
:param url_suffix: Optional path suffix to append to the URL. Must start with '/', e.g. '/cancel' or '/result'.
"""
req_id = uuid.uuid4()
header = self.get_headers()
params = {"requestId": str(req_id)}
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements/{query_id}"
if url_suffix:
url += url_suffix
return header, params, url

def check_query_output(self, query_ids: list[str]) -> None:
Expand Down Expand Up @@ -413,6 +418,16 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str |
status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params)
return self._process_response(status_code, resp)

def _cancel_sql_api_query_execution(self, query_id: str) -> dict[str, str | list[str]]:
self.log.info("Cancelling query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id, "/cancel")
status_code, resp = self._make_api_call_with_retries("POST", url, header, params)
return self._process_response(status_code, resp)

def cancel_queries(self, query_ids: list[str]) -> None:
for query_id in query_ids:
self._cancel_sql_api_query_execution(query_id)

@staticmethod
def _should_retry_on_error(exception) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,10 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] |
self._hook.query_ids = self.query_ids
else:
self.log.info("%s completed successfully.", self.task_id)

def on_kill(self) -> None:
"""Cancel the running query."""
if self.query_ids:
self.log.info("Cancelling the query ids %s", self.query_ids)
self._hook.cancel_queries(self.query_ids)
self.log.info("Query ids %s cancelled successfully", self.query_ids)
Original file line number Diff line number Diff line change
Expand Up @@ -1433,3 +1433,39 @@ def test_make_api_call_with_retries_json_decode_error_prevention(self, mock_requ

failed_response.raise_for_status.assert_called_once()
failed_response.json.assert_not_called()

@mock.patch(f"{HOOK_PATH}.get_request_url_header_params")
def test_cancel_sql_api_query_execution(self, mock_get_url_header_params, mock_requests):
"""Test _cancel_sql_api_query_execution makes POST request with /cancel suffix."""
query_id = "test-query-id"
mock_get_url_header_params.return_value = (
HEADERS,
{"requestId": "uuid"},
f"{API_URL}/{query_id}/cancel",
)
mock_requests.request.return_value = create_successful_response_mock(
{"status": "success", "message": "Statement cancelled."}
)

hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
hook._cancel_sql_api_query_execution(query_id)

mock_get_url_header_params.assert_called_once_with(query_id, "/cancel")
mock_requests.request.assert_called_once_with(
method="post",
url=f"{API_URL}/{query_id}/cancel",
headers=HEADERS,
params={"requestId": "uuid"},
json=None,
)

@mock.patch(f"{HOOK_PATH}._cancel_sql_api_query_execution")
def test_cancel_queries(self, mock_cancel_execution):
"""Test cancel_queries calls _cancel_sql_api_query_execution for each query id."""
query_ids = ["query-1", "query-2", "query-3"]

hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
hook.cancel_queries(query_ids)

assert mock_cancel_execution.call_count == 3
mock_cancel_execution.assert_has_calls([call("query-1"), call("query-2"), call("query-3")])
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,33 @@ def test_snowflake_sql_api_execute_operator_polling_failed(
with pytest.raises(AirflowException):
operator.execute(context=None)
mock_check_query_output.assert_not_called()

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
def test_snowflake_sql_api_on_kill_cancels_queries(self, mock_cancel_queries):
"""Test that on_kill cancels running queries."""
operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
)
operator.query_ids = ["uuid1", "uuid2"]

operator.on_kill()

mock_cancel_queries.assert_called_once_with(["uuid1", "uuid2"])

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.cancel_queries")
def test_snowflake_sql_api_on_kill_no_queries(self, mock_cancel_queries):
"""Test that on_kill does nothing when no query ids exist."""
operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
)
operator.query_ids = []

operator.on_kill()

mock_cancel_queries.assert_not_called()