Skip to content

Commit

Permalink
Allow AWSAthenaHook to get more than 1000/first page of results (#6075)
Browse files Browse the repository at this point in the history
Co-authored-by: Dylan Joss <dylanjoss@gmail.com>
GitOrigin-RevId: 07b81029ebc2a296fb54181f2cec11fcc7704d9d
  • Loading branch information
2 people authored and Cloud Composer Team committed Sep 12, 2024
1 parent e63024c commit 8b410f8
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 8 deletions.
64 changes: 57 additions & 7 deletions airflow/providers/amazon/aws/hooks/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ def run_query(self, query, query_context, result_configuration, client_request_t
:type workgroup: str
:return: str
"""
response = self.get_conn().start_query_execution(QueryString=query,
ClientRequestToken=client_request_token,
QueryExecutionContext=query_context,
ResultConfiguration=result_configuration,
WorkGroup=workgroup)
params = {
'QueryString': query,
'QueryExecutionContext': query_context,
'ResultConfiguration': result_configuration,
'WorkGroup': workgroup
}
if client_request_token:
params['ClientRequestToken'] = client_request_token
response = self.get_conn().start_query_execution(**params)
query_execution_id = response['QueryExecutionId']
return query_execution_id

Expand Down Expand Up @@ -109,13 +113,17 @@ def get_state_change_reason(self, query_execution_id):
# The error is being absorbed to implement retries.
return reason # pylint: disable=lost-exception

def get_query_results(self, query_execution_id):
def get_query_results(self, query_execution_id, next_token_id=None, max_results=1000):
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else dict of query output
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:param next_token_id: The token that specifies where to start pagination.
:type next_token_id: str
:param max_results: The maximum number of results (rows) to return in this request.
:type max_results: int
:return: dict
"""
query_state = self.check_query_status(query_execution_id)
Expand All @@ -125,7 +133,49 @@ def get_query_results(self, query_execution_id):
elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
return self.get_conn().get_query_results(QueryExecutionId=query_execution_id)
result_params = {
'QueryExecutionId': query_execution_id,
'MaxResults': max_results
}
if next_token_id:
result_params['NextToken'] = next_token_id
return self.get_conn().get_query_results(**result_params)

def get_query_results_paginator(self, query_execution_id, max_items=None,
page_size=None, starting_token=None):
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else a paginator to iterate through pages of results. If you
wish to get all results at once, call build_full_result() on the returned PageIterator
:param query_execution_id: Id of submitted athena query
:type query_execution_id: str
:param max_items: The total number of items to return.
:type max_items: int
:param page_size: The size of each page.
:type page_size: int
:param starting_token: A token to specify where to start paginating.
:type starting_token: str
:return: PageIterator
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
self.log.error('Invalid Query state (null)')
return None
if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
result_params = {
'QueryExecutionId': query_execution_id,
'PaginationConfig': {
'MaxItems': max_items,
'PageSize': page_size,
'StartingToken': starting_token

}
}
paginator = self.get_conn().get_paginator('get_query_results')
return paginator.paginate(**result_params)

def poll_query_status(self, query_execution_id, max_tries=None):
"""
Expand Down
172 changes: 172 additions & 0 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import unittest
from unittest import mock

from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook

MOCK_DATA = {
'query': 'SELECT * FROM TEST_TABLE',
'database': 'TEST_DATABASE',
'outputLocation': 's3://test_s3_bucket/',
'client_request_token': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
'workgroup': 'primary',
'query_execution_id': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
'next_token_id': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
'max_items': 1000
}

mock_query_context = {
'Database': MOCK_DATA['database']
}
mock_result_configuration = {
'OutputLocation': MOCK_DATA['outputLocation']
}

MOCK_RUNNING_QUERY_EXECUTION = {'QueryExecution': {'Status': {'State': 'RUNNING'}}}
MOCK_SUCCEEDED_QUERY_EXECUTION = {'QueryExecution': {'Status': {'State': 'SUCCEEDED'}}}

MOCK_QUERY_EXECUTION = {'QueryExecutionId': MOCK_DATA['query_execution_id']}


class TestAWSAthenaHook(unittest.TestCase):

def setUp(self):
self.athena = AWSAthenaHook(sleep_time=0)

def test_init(self):
self.assertEqual(self.athena.aws_conn_id, 'aws_default')
self.assertEqual(self.athena.sleep_time, 0)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_query_without_token(self, mock_conn):
mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION
result = self.athena.run_query(query=MOCK_DATA['query'],
query_context=mock_query_context,
result_configuration=mock_result_configuration)
expected_call_params = {
'QueryString': MOCK_DATA['query'],
'QueryExecutionContext': mock_query_context,
'ResultConfiguration': mock_result_configuration,
'WorkGroup': MOCK_DATA['workgroup']
}
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
self.assertEqual(result, MOCK_DATA['query_execution_id'])

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_run_query_with_token(self, mock_conn):
mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION
result = self.athena.run_query(query=MOCK_DATA['query'],
query_context=mock_query_context,
result_configuration=mock_result_configuration,
client_request_token=MOCK_DATA['client_request_token'])
expected_call_params = {
'QueryString': MOCK_DATA['query'],
'QueryExecutionContext': mock_query_context,
'ResultConfiguration': mock_result_configuration,
'ClientRequestToken': MOCK_DATA['client_request_token'],
'WorkGroup': MOCK_DATA['workgroup']
}
mock_conn.return_value.start_query_execution.assert_called_with(**expected_call_params)
self.assertEqual(result, MOCK_DATA['query_execution_id'])

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
self.assertIsNone(result)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_query_results_with_default_params(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'])
expected_call_params = {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'MaxResults': 1000
}
mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_query_results_with_next_token(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results(query_execution_id=MOCK_DATA['query_execution_id'],
next_token_id=MOCK_DATA['next_token_id'])
expected_call_params = {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'NextToken': MOCK_DATA['next_token_id'],
'MaxResults': 1000
}
mock_conn.return_value.get_query_results.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_paginator_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
self.assertIsNone(result)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_paginator_with_default_params(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'])
expected_call_params = {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'PaginationConfig': {
'MaxItems': None,
'PageSize': None,
'StartingToken': None

}
}
mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_get_paginator_with_pagination_config(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
self.athena.get_query_results_paginator(query_execution_id=MOCK_DATA['query_execution_id'],
max_items=MOCK_DATA['max_items'],
page_size=MOCK_DATA['max_items'],
starting_token=MOCK_DATA['next_token_id'])
expected_call_params = {
'QueryExecutionId': MOCK_DATA['query_execution_id'],
'PaginationConfig': {
'MaxItems': MOCK_DATA['max_items'],
'PageSize': MOCK_DATA['max_items'],
'StartingToken': MOCK_DATA['next_token_id']

}
}
mock_conn.return_value.get_paginator.return_value.paginate.assert_called_with(**expected_call_params)

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_poll_query_when_final(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION
result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'])
mock_conn.return_value.get_query_execution.assert_called_once()
self.assertEqual(result, 'SUCCEEDED')

@mock.patch.object(AWSAthenaHook, 'get_conn')
def test_hook_poll_query_with_timeout(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
result = self.athena.poll_query_status(query_execution_id=MOCK_DATA['query_execution_id'],
max_tries=1)
mock_conn.return_value.get_query_execution.assert_called_once()
self.assertEqual(result, 'RUNNING')


if __name__ == '__main__':
unittest.main()
1 change: 0 additions & 1 deletion tests/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)

MISSING_TEST_FILES = {
'tests/providers/amazon/aws/hooks/test_athena.py',
'tests/providers/apache/cassandra/sensors/test_record.py',
'tests/providers/apache/cassandra/sensors/test_table.py',
'tests/providers/apache/hdfs/sensors/test_web_hdfs.py',
Expand Down

0 comments on commit 8b410f8

Please sign in to comment.