Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 36 additions & 44 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
"""
from __future__ import annotations

from time import sleep
import warnings
from typing import Any

from botocore.paginate import PageIterator

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait


class AthenaHook(AwsBaseHook):
Expand All @@ -38,8 +40,7 @@ class AthenaHook(AwsBaseHook):
Provide thick wrapper around
:external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.

:param sleep_time: Time (in seconds) to wait between two consecutive calls
to check query status on Athena.
:param sleep_time: obsolete, please use the parameter of `poll_query_status` method instead
:param log_query: Whether to log athena query and other execution params
when it's executed. Defaults to *True*.

Expand All @@ -65,9 +66,20 @@ class AthenaHook(AwsBaseHook):
"CANCELLED",
)

def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = True, **kwargs: Any) -> None:
def __init__(
self, *args: Any, sleep_time: int | None = None, log_query: bool = True, **kwargs: Any
) -> None:
super().__init__(client_type="athena", *args, **kwargs) # type: ignore
self.sleep_time = sleep_time
if sleep_time is not None:
self.sleep_time = sleep_time
warnings.warn(
"The `sleep_time` parameter of the Athena hook is deprecated, "
"please pass this parameter to the poll_query_status method instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
else:
self.sleep_time = 30 # previous default value
self.log_query = log_query

def run_query(
Expand Down Expand Up @@ -229,51 +241,31 @@ def get_query_results_paginator(
return paginator.paginate(**result_params)

def poll_query_status(
self,
query_execution_id: str,
max_polling_attempts: int | None = None,
self, query_execution_id: str, max_polling_attempts: int | None = None, sleep_time: int | None = None
) -> str | None:
"""Poll the state of a submitted query until it reaches final state.

:param query_execution_id: ID of submitted athena query
:param max_polling_attempts: Number of times to poll for query state
before function exits
:param max_polling_attempts: Number of times to poll for query state before function exits
:param sleep_time: Time (in seconds) to wait between two consecutive query status checks.
:return: One of the final states
"""
try_number = 1
final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.info(
"Query execution id: %s, trial %s: Invalid query state. Retrying again",
query_execution_id,
try_number,
)
elif query_state in self.TERMINAL_STATES:
self.log.info(
"Query execution id: %s, trial %s: Query execution completed. Final state is %s",
query_execution_id,
try_number,
query_state,
)
final_query_state = query_state
break
else:
self.log.info(
"Query execution id: %s, trial %s: Query is still in non-terminal state - %s",
query_execution_id,
try_number,
query_state,
)
if (
max_polling_attempts and try_number >= max_polling_attempts
): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
sleep(self.sleep_time)
return final_query_state
try:
wait(
waiter=self.get_waiter("query_complete"),
waiter_delay=sleep_time or self.sleep_time,
max_attempts=max_polling_attempts or 120,
args={"QueryExecutionId": query_execution_id},
failure_message=f"Error while waiting for query {query_execution_id} to complete",
status_message=f"Query execution id: {query_execution_id}, "
f"Query is still in non-terminal state",
status_args=["QueryExecution.Status.State"],
)
except AirflowException as error:
# this function does not raise errors to keep previous behavior.
self.log.warning(error)
finally:
return self.check_query_status(query_execution_id)

def get_output_location(self, query_execution_id: str) -> str:
"""Get the output location of the query results in S3 URI format.
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time, log_query=self.log_query)
return AthenaHook(self.aws_conn_id, log_query=self.log_query)

def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena."""
Expand All @@ -104,6 +104,7 @@ def execute(self, context: Context) -> str | None:
query_status = self.hook.poll_query_status(
self.query_execution_id,
max_polling_attempts=self.max_polling_attempts,
sleep_time=self.sleep_time,
)

if query_status in AthenaHook.FAILURE_STATES:
Expand Down Expand Up @@ -139,4 +140,4 @@ def on_kill(self) -> None:
self.log.info(
"Polling Athena for query with id %s to reach final state", self.query_execution_id
)
self.hook.poll_query_status(self.query_execution_id)
self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time)
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self.max_retries = max_retries

def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(self.query_execution_id, self.max_retries)
state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time)

if state in self.FAILURE_STATES:
raise AirflowException("Athena sensor failed")
Expand All @@ -88,4 +88,4 @@ def poke(self, context: Context) -> bool:
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
return AthenaHook(self.aws_conn_id)
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/waiters/athena.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"version": 2,
"waiters": {
"query_complete": {
"operation": "GetQueryExecution",
"delay": 30,
"maxAttempts": 120,
"acceptors": [
{
"expected": "SUCCEEDED",
"matcher": "path",
"state": "success",
"argument": "QueryExecution.Status.State"
},
{
"expected": "FAILED",
"matcher": "path",
"state": "failure",
"argument": "QueryExecution.Status.State"
},
{
"expected": "CANCELLED",
"matcher": "path",
"state": "failure",
"argument": "QueryExecution.Status.State"
}
]
}
}
}
12 changes: 6 additions & 6 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@

class TestAthenaHook:
def setup_method(self):
self.athena = AthenaHook(sleep_time=0)
self.athena = AthenaHook()

def test_init(self):
assert self.athena.aws_conn_id == "aws_default"
assert self.athena.sleep_time == 0

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_without_token(self, mock_conn):
Expand Down Expand Up @@ -104,7 +103,7 @@ def test_hook_run_query_log_query(self, mock_conn, log):
@mock.patch.object(AthenaHook, "log")
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_query_no_log_query(self, mock_conn, log):
athena_hook_no_log_query = AthenaHook(sleep_time=0, log_query=False)
athena_hook_no_log_query = AthenaHook(log_query=False)
athena_hook_no_log_query.run_query(
query=MOCK_DATA["query"],
query_context=mock_query_context,
Expand Down Expand Up @@ -176,16 +175,17 @@ def test_hook_get_paginator_with_pagination_config(self, mock_conn):
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_poll_query_when_final(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
result = self.athena.poll_query_status(query_execution_id=MOCK_DATA["query_execution_id"])
result = self.athena.poll_query_status(
query_execution_id=MOCK_DATA["query_execution_id"], sleep_time=0
)
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == "SUCCEEDED"

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.poll_query_status(
query_execution_id=MOCK_DATA["query_execution_id"],
max_polling_attempts=1,
query_execution_id=MOCK_DATA["query_execution_id"], max_polling_attempts=1, sleep_time=0
)
mock_conn.return_value.get_query_execution.assert_called_once()
assert result == "RUNNING"
Expand Down
63 changes: 4 additions & 59 deletions tests/providers/amazon/aws/operators/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def test_init(self):
assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0

assert self.athena.hook.sleep_time == 0

@mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
Expand All @@ -90,11 +88,7 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec
@mock.patch.object(
AthenaHook,
"check_query_status",
side_effect=(
"RUNNING",
"RUNNING",
"SUCCEEDED",
),
side_effect="SUCCEEDED",
)
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
Expand All @@ -107,39 +101,9 @@ def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
assert mock_check_query_status.call_count == 3

@mock.patch.object(
AthenaHook,
"check_query_status",
side_effect=(
None,
None,
),
)
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status):
with pytest.raises(Exception):
self.athena.execute({})
mock_run_query.assert_called_once_with(
MOCK_DATA["query"],
query_context,
result_configuration,
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
assert mock_check_query_status.call_count == 3

@mock.patch.object(AthenaHook, "get_state_change_reason")
@mock.patch.object(
AthenaHook,
"check_query_status",
side_effect=(
"RUNNING",
"FAILED",
),
)
@mock.patch.object(AthenaHook, "check_query_status", return_value="FAILED")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failure_query(
Expand All @@ -154,18 +118,9 @@ def test_hook_run_failure_query(
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
assert mock_check_query_status.call_count == 2
assert mock_get_state_change_reason.call_count == 1

@mock.patch.object(
AthenaHook,
"check_query_status",
side_effect=(
"RUNNING",
"RUNNING",
"CANCELLED",
),
)
@mock.patch.object(AthenaHook, "check_query_status", return_value="CANCELLED")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
Expand All @@ -178,17 +133,8 @@ def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_qu
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
assert mock_check_query_status.call_count == 3

@mock.patch.object(
AthenaHook,
"check_query_status",
side_effect=(
"RUNNING",
"RUNNING",
"RUNNING",
),
)
@mock.patch.object(AthenaHook, "check_query_status", return_value="RUNNING")
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status):
Expand All @@ -201,7 +147,6 @@ def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, m
MOCK_DATA["client_request_token"],
MOCK_DATA["workgroup"],
)
assert mock_check_query_status.call_count == 3

@mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
Expand Down