Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1274,16 +1274,32 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
)
else:
waiter = self.hook.get_waiter("serverless_job_completed")
wait(
waiter=waiter,
waiter_max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
args={"applicationId": self.application_id, "jobRunId": self.job_id},
failure_message="Serverless Job failed",
status_message="Serverless Job status is",
status_args=["jobRun.state", "jobRun.stateDetails"],
)
try:
waiter = self.hook.get_waiter("serverless_job_completed")
wait(
waiter=waiter,
waiter_max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
args={"applicationId": self.application_id, "jobRunId": self.job_id},
failure_message="Serverless Job failed",
status_message="Serverless Job status is",
status_args=["jobRun.state", "jobRun.stateDetails"],
)
except AirflowException as e:
if "Waiter error: max attempts reached" in str(e):
self.log.info(
"Cancelling EMR Serverless job %s due to max waiter attempts reached", self.job_id
)
try:
self.hook.conn.cancel_job_run(
applicationId=self.application_id, jobRunId=self.job_id
)
except Exception:
self.log.exception(
"Failed to cancel EMR Serverless job %s after waiter timeout",
self.job_id,
)
raise

return self.job_id

Expand All @@ -1292,7 +1308,13 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None

if validated_event["status"] == "success":
self.log.info("Serverless job completed")
return validated_event["job_id"]
return validated_event["job_details"]["job_id"]
self.log.info("Cancelling EMR Serverless job %s", self.job_id)
self.hook.conn.cancel_job_run(
applicationId=validated_event["job_details"]["application_id"],
jobRunId=validated_event["job_details"]["job_id"],
)
raise AirflowException("EMR Serverless job failed or timed out in deferrable mode")

def on_kill(self) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ def __init__(
failure_message="Serverless Job failed",
status_message="Serverless Job status is",
status_queries=["jobRun.state", "jobRun.stateDetails"],
return_key="job_id",
return_value=job_id,
return_key="job_details",
return_value={"application_id": application_id, "job_id": job_id},
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,39 +502,198 @@ def test_job_run_app_not_started(self, mock_conn, mock_get_waiter):
name=default_name,
)

@mock.patch("time.sleep", return_value=True)
@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_job_run_app_not_started_app_failed(self, mock_conn, mock_get_waiter, mock_time):
error1 = WaiterError(
name="test_name",
reason="test-reason",
last_response={"application": {"state": "CREATING", "stateDetails": "test-details"}},
def test_execute_max_attempts_cancel_job(self, mock_conn, mock_get_waiter, sleep_mock):
application_id = "test-app-id"
job_run_id = "test-job-id"
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
mock_get_waiter.return_value.wait.side_effect = AirflowException("Waiter error: max attempts reached")
operator = EmrServerlessStartJobOperator(
task_id="test_task",
application_id=application_id,
execution_role_arn="arn:aws:iam::123456789012:role/test-role",
job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/my-script.py"}},
client_request_token="token",
configuration_overrides={},
waiter_delay=1,
waiter_max_attempts=1,
wait_for_completion=True,
)

mock_context = mock.MagicMock()
with mock.patch.object(operator.log, "info") as mock_log_info:
with pytest.raises(AirflowException, match="Waiter error: max attempts reached"):
operator.execute(mock_context)
mock_conn.cancel_job_run.assert_called_once_with(
applicationId=application_id, jobRunId=job_run_id
)
log_msgs = [call.args[0] for call in mock_log_info.call_args_list]
assert any("Cancelling EMR Serverless job" in msg for msg in log_msgs)

@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_execute_max_attempts_cancel_job_cancel_fails(self, mock_conn, mock_get_waiter, sleep_mock):
application_id = "test-app-id"
job_run_id = "test-job-id"
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
mock_get_waiter.return_value.wait.side_effect = AirflowException("Waiter error: max attempts reached")
mock_conn.cancel_job_run.side_effect = Exception(
"Failed to cancel EMR Serverless job after waiter timeout"
)
error2 = WaiterError(
name="test_name",
reason="Waiter encountered a terminal failure state:",
last_response={"application": {"state": "TERMINATED", "stateDetails": "test-details"}},
operator = EmrServerlessStartJobOperator(
task_id="test_task",
application_id=application_id,
execution_role_arn="arn:aws:iam::123456789012:role/test-role",
job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/my-script.py"}},
client_request_token="token",
configuration_overrides={},
waiter_delay=1,
waiter_max_attempts=1,
wait_for_completion=True,
)
mock_context = mock.MagicMock()

with (
mock.patch.object(operator.log, "exception") as mock_log_exception,
mock.patch.object(operator.log, "info") as mock_log_info,
):
with pytest.raises(AirflowException, match="Waiter error: max attempts reached"):
operator.execute(mock_context)
mock_conn.cancel_job_run.assert_called_once_with(
applicationId=application_id,
jobRunId=job_run_id,
)
log_msgs = [call.args[0] for call in mock_log_exception.call_args_list]
assert any("Failed to cancel EMR Serverless job" in msg for msg in log_msgs)
info_msgs = [call.args[0] for call in mock_log_info.call_args_list]
assert any("Cancelling EMR Serverless job" in msg for msg in info_msgs)

@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_execute_waiter_other_exception_does_not_cancel(self, mock_conn, mock_get_waiter, sleep_mock):
    """Check that a waiter error other than the max-attempts timeout does not cancel the job run.

    The waiter is made to raise ``AirflowException("Other failure")``; the test
    expects ``execute`` to re-raise it and ``cancel_job_run`` to stay uncalled.
    """
    # Stub the EMR Serverless client: application already started, job submission succeeds.
    mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
    mock_conn.start_job_run.return_value = {
        "jobRunId": "job-id",
        "ResponseMetadata": {"HTTPStatusCode": 200},
    }
    # The waiter fails with a message that is not the max-attempts one.
    mock_get_waiter.return_value.wait.side_effect = AirflowException("Other failure")

    operator_kwargs = {
        "task_id": "test_task",
        "application_id": "app-id",
        "execution_role_arn": "arn",
        "job_driver": {"sparkSubmit": {"entryPoint": "s3://x"}},
        "client_request_token": "token",
        "waiter_delay": 1,
        "waiter_max_attempts": 1,
        "wait_for_completion": True,
    }
    start_job_op = EmrServerlessStartJobOperator(**operator_kwargs)
    context = mock.MagicMock()

    with pytest.raises(AirflowException, match="Other failure"):
        start_job_op.execute(context)

    # Checked after the raises block so the assertion actually runs.
    mock_conn.cancel_job_run.assert_not_called()

@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_execute_no_wait_for_completion_does_not_cancel(self, mock_conn, mock_get_waiter, sleep_mock):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": "job-id",
"ResponseMetadata": {"HTTPStatusCode": 200},
}

operator = EmrServerlessStartJobOperator(
task_id="test_task",
application_id="app-id",
execution_role_arn="arn",
job_driver={"sparkSubmit": {"entryPoint": "s3://x"}},
client_request_token="token",
wait_for_completion=False,
)
mock_get_waiter().wait.side_effect = [error1, error2]

job_id = operator.execute(mock.MagicMock())
assert job_id == "job-id"
mock_conn.cancel_job_run.assert_not_called()

@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_execute_deferrable_does_not_cancel(self, mock_conn, mock_get_waiter, sleep_mock):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"jobRunId": "job-id",
"ResponseMetadata": {"HTTPStatusCode": 200},
}

operator = EmrServerlessStartJobOperator(
task_id=task_id,
client_request_token=client_request_token,
task_id="test_task",
application_id="app-id",
execution_role_arn="arn",
job_driver={"sparkSubmit": {"entryPoint": "s3://x"}},
client_request_token="token",
wait_for_completion=True,
deferrable=True,
)
operator.defer = mock.MagicMock()
operator.execute(mock.MagicMock())
operator.defer.assert_called_once()
mock_conn.cancel_job_run.assert_not_called()

@mock.patch("time.sleep", return_value=None)
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_execute_complete_deferrable_failure_triggers_cancel(
self, mock_conn, mock_get_waiter, sleep_mock
):
application_id = "test-app-id"
job_run_id = "test-job-id"
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
operator = EmrServerlessStartJobOperator(
task_id="test_task",
application_id=application_id,
execution_role_arn=execution_role_arn,
job_driver=job_driver,
configuration_overrides=configuration_overrides,
execution_role_arn="arn:aws:iam::123456789012:role/test-role",
job_driver={"sparkSubmit": {"entryPoint": "s3://bucket/script.py"}},
client_request_token="token",
configuration_overrides={},
waiter_delay=1,
waiter_max_attempts=3,
wait_for_completion=True,
deferrable=True,
)
with pytest.raises(AirflowException) as ex_message:
operator.execute(self.mock_context)
assert "Serverless Application failed to start:" in str(ex_message.value)
assert operator.wait_for_completion is True
assert mock_get_waiter().wait.call_count == 2
operator.job_id = job_run_id
failed_event = {
"status": "error",
"job_details": {"application_id": application_id, "job_id": job_run_id},
}

mock_context = mock.MagicMock()
with mock.patch.object(operator.log, "info") as mock_log:
with pytest.raises(
AirflowException, match="EMR Serverless job failed or timed out in deferrable mode"
):
operator.execute_complete(mock_context, failed_event)
mock_conn.cancel_job_run.assert_called_once_with(
applicationId=application_id,
jobRunId=job_run_id,
)
log_msgs = [call.args[0] for call in mock_log.call_args_list]
assert any("Cancelling EMR Serverless job" in msg for msg in log_msgs)

@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
Expand Down