Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ class EmrContainerHook(AwsBaseHook):
"CANCEL_PENDING",
)
SUCCESS_STATES = ("COMPLETED",)
TERMINAL_STATES = (
"COMPLETED",
"FAILED",
"CANCELLED",
"CANCEL_PENDING",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is "CANCEL_PENDING" really a terminal state?

Looks good to me other than this question.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"CANCEL_PENDING" is being treated as a failure state for the EMR container sensor, which is why I included it as a terminal state. Should it be removed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know the EMR behavior well enough to say. If you feel it's intentional and the right thing to do, then leave it in. Just asking because the name strikes me as an intermediate state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going by the code here and here, I'll leave it in as a terminal state.

)

def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None:
super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore
Expand Down Expand Up @@ -228,19 +234,16 @@ def poll_query_status(
try_number = 1
final_query_state = None # Query state when query reaches final state or max_tries reached

# TODO: Make this logic a little bit more robust.
# Currently this polls until the state is *not* one of the INTERMEDIATE_STATES
# While that should work in most cases...it might not. :)
while True:
query_state = self.check_query_status(job_id)
if query_state is None:
self.log.info("Try %s: Invalid query state. Retrying again", try_number)
elif query_state in self.INTERMEDIATE_STATES:
self.log.info("Try %s: Query is still in an intermediate state - %s", try_number, query_state)
else:
elif query_state in self.TERMINAL_STATES:
self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
final_query_state = query_state
break
else:
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
if max_tries and try_number >= max_tries: # Break loop if max_tries reached
final_query_state = query_state
break
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/amazon/aws/hooks/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@
'virtualClusterId': 'vc1234',
}

# Canned ``describe_job_run`` response for a job run whose state is COMPLETED.
JOB1_RUN_DESCRIPTION = {
    'jobRun': {
        'id': 'job123456',
        'virtualClusterId': 'vc1234',
        'state': 'COMPLETED',
    }
}

# Canned ``describe_job_run`` response for a job run whose state is RUNNING.
JOB2_RUN_DESCRIPTION = {
    'jobRun': {
        'id': 'job123456',
        'virtualClusterId': 'vc1234',
        'state': 'RUNNING',
    }
}


class TestEmrContainerHook(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -55,3 +71,29 @@ def test_submit_job(self, mock_session):
client_request_token="uuidtoken",
)
assert emr_containers_job == 'job123456'

@mock.patch("boto3.session.Session")
def test_query_status_polling_when_terminal(self, mock_session):
    """A job reported as COMPLETED is polled once and that state is returned."""
    # Any client obtained from the patched boto3 session is this mock,
    # which always answers describe_job_run with the COMPLETED fixture.
    fake_client = mock.MagicMock()
    fake_client.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
    fake_session = mock.MagicMock()
    fake_session.client.return_value = fake_client
    mock_session.return_value = fake_session

    final_state = self.emr_containers.poll_query_status(job_id='job123456')

    # should only poll once since query is already in terminal state
    fake_client.describe_job_run.assert_called_once()
    assert final_state == 'COMPLETED'

@mock.patch("boto3.session.Session")
def test_query_status_polling_with_timeout(self, mock_session):
    """A job that stays RUNNING is polled ``max_tries`` times, then RUNNING is returned."""
    # Any client obtained from the patched boto3 session is this mock,
    # which always answers describe_job_run with the RUNNING fixture.
    fake_client = mock.MagicMock()
    fake_client.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
    fake_session = mock.MagicMock()
    fake_session.client.return_value = fake_client
    mock_session.return_value = fake_session

    last_seen_state = self.emr_containers.poll_query_status(job_id='job123456', max_tries=2)

    # should poll until max_tries is reached since query is in non-terminal state
    assert fake_client.describe_job_run.call_count == 2
    assert last_seen_state == 'RUNNING'