Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class AthenaHook(AwsBaseHook):
'CANCELLED',
)
SUCCESS_STATES = ('SUCCEEDED',)
TERMINAL_STATES = (
"SUCCEEDED",
"FAILED",
"CANCELLED",
)

def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
super().__init__(client_type='athena', *args, **kwargs) # type: ignore
Expand Down Expand Up @@ -200,16 +205,14 @@ def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] =
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.info('Trial %s: Invalid query state. Retrying again', try_number)
elif query_state in self.INTERMEDIATE_STATES:
self.log.info(
'Trial %s: Query is still in an intermediate state - %s', try_number, query_state
)
else:
elif query_state in self.TERMINAL_STATES:
self.log.info(
'Trial %s: Query execution completed. Final state is %s}', try_number, query_state
)
final_query_state = query_state
break
else:
self.log.info('Trial %s: Query is still in non-terminal state - %s', try_number, query_state)
if max_tries and try_number >= max_tries: # Break loop if max_tries reached
final_query_state = query_state
break
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/operators/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_init(self):

assert self.athena.hook.sleep_time == 0

@mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCESS",))
Copy link
Contributor Author

@victorphoenix3 victorphoenix3 Mar 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'SUCCESS' is not a state for athena queries, "SUCCEEDED" is. The tests were failing because of this discrepancy

@mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, 'get_conn')
def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
Expand All @@ -92,7 +92,7 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec
side_effect=(
"RUNNING",
"RUNNING",
"SUCCESS",
"SUCCEEDED",
),
)
@mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, m
)
assert mock_check_query_status.call_count == 3

@mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCESS",))
@mock.patch.object(AthenaHook, 'check_query_status', side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
@mock.patch.object(AthenaHook, 'get_conn')
def test_return_value(self, mock_conn, mock_run_query, mock_check_query_status):
Expand Down