102 changes: 66 additions & 36 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -24,6 +24,8 @@
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from botocore.exceptions import WaiterError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
@@ -666,6 +668,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
:param deferrable: If True, the operator will wait asynchronously for the job flow to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
when a failure occurs after the job flow has been created. Cleanup failures do not mask the
original exception. (default: True)
"""

aws_hook_class = EmrHook
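
For context, a minimal usage sketch of the new flag. This is illustrative only: the task id, the JOB_FLOW_OVERRIDES dict and the connection id are hypothetical and not part of this PR.

from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

create_job_flow = EmrCreateJobFlowOperator(
    task_id="create_emr_job_flow",          # hypothetical task id
    job_flow_overrides=JOB_FLOW_OVERRIDES,  # hypothetical cluster config dict
    aws_conn_id="aws_default",
    wait_for_completion=True,
    # New in this PR: best-effort termination of the cluster if a
    # post-creation step (e.g. the completion waiter) fails.
    terminate_job_flow_on_failure=True,
)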
@@ -692,6 +697,7 @@ def __init__(
waiter_max_attempts: int | None = None,
waiter_delay: int | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
terminate_job_flow_on_failure: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
@@ -701,6 +707,7 @@ def __init__(
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable
self.terminate_job_flow_on_failure = terminate_job_flow_on_failure

if wait_policy is not None:
warnings.warn(
@@ -741,49 +748,72 @@ def execute(self, context: Context) -> str | None:

self._job_flow_id = response["JobFlowId"]
self.log.info("Job flow with id %s created", self._job_flow_id)
EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
)
if self._job_flow_id:
EmrLogsLink.persist(
try:
EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
if self.wait_for_completion:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
else:
self.hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
if self._job_flow_id:
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
return self._job_flow_id
if self.wait_for_completion:
waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]

if self.deferrable:
self.defer(
trigger=EmrCreateJobFlowTrigger(
job_flow_id=self._job_flow_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
)
else:
self.hook.get_waiter(waiter_name).wait(
ClusterId=self._job_flow_id,
WaiterConfig=prune_dict(
{
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
}
),
)
return self._job_flow_id

# Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors).
except WaiterError:
if self._job_flow_id:
if self.terminate_job_flow_on_failure:
self.log.warning(
"Task failed after creating EMR job flow %s.",
self._job_flow_id,
)
try:
self.log.info(
"Attempting termination of EMR job flow %s.",
self._job_flow_id,
)

self.hook.conn.terminate_job_flows(JobFlowIds=[self._job_flow_id])
except Exception:
self.log.exception(
"Failed to terminate EMR job flow %s after task failure.",
self._job_flow_id,
)
raise

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
validated_event = validate_execute_complete_event(event)
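
Stripped of the link-persistence and deferral details, the try/except added to execute() above reduces to the following standalone sketch. The helper name is hypothetical, and waiter, emr_client and log stand in for the operator's hook internals.

import logging

from botocore.exceptions import WaiterError

log = logging.getLogger(__name__)


def wait_with_cleanup(waiter, emr_client, job_flow_id, terminate_on_failure=True):
    try:
        # Any post-creation step (here, waiting for the cluster) may raise.
        waiter.wait(ClusterId=job_flow_id)
        return job_flow_id
    except WaiterError:
        if terminate_on_failure and job_flow_id:
            try:
                # Best effort: tear down the cluster that was already created.
                emr_client.terminate_job_flows(JobFlowIds=[job_flow_id])
            except Exception:
                # Cleanup failures are only logged, never raised, so the
                # original WaiterError is what the caller ultimately sees.
                log.exception("Failed to terminate EMR job flow %s", job_flow_id)
        raise  # re-raise the original error unchanged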
Tests for EmrCreateJobFlowOperator:
@@ -23,6 +23,7 @@
from unittest.mock import MagicMock, patch

import pytest
from botocore.exceptions import ClientError, WaiterError
from botocore.waiter import Waiter
from jinja2 import StrictUndefined

@@ -232,6 +233,7 @@ def test_create_job_flow_deferrable(self, mocked_hook_client):

self.operator.deferrable = True
self.operator.wait_for_completion = True

with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)

@@ -261,3 +263,77 @@ def test_wait_policy_deprecation_warning(self):
task_id=TASK_ID,
wait_policy=WaitPolicy.WAIT_FOR_COMPLETION,
)

def test_cleanup_on_post_create_failure(self, mocked_hook_client):
"""
Ensure that if the job flow is created successfully but a subsequent
post-create step fails (e.g. waiter / DescribeCluster),
the operator attempts best-effort cleanup.
"""
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.wait_for_completion = True
self.operator.terminate_job_flow_on_failure = True

waiter_error = WaiterError(
"ClusterRunning",
"You are not authorized to perform this operation",
{},
)

with (
patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
):
mock_get_waiter.return_value.wait.side_effect = waiter_error

with pytest.raises(WaiterError) as exc:
self.operator.execute(self.mock_context)

# Original exception must be propagated unchanged
assert exc.value is waiter_error

# Cleanup must be attempted
mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])

def test_cleanup_failure_does_not_mask_original_exception(self, mocked_hook_client):
"""
Ensure that failure during cleanup does not override
the original post-create exception.
"""
mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

self.operator.wait_for_completion = True
self.operator.terminate_job_flow_on_failure = True

waiter_error = WaiterError(
"ClusterRunning",
"You are not authorized to perform this operation",
{},
)

cleanup_error = ClientError(
error_response={
"Error": {
"Code": "UnauthorizedOperation",
"Message": "You are not authorized to perform this operation",
}
},
operation_name="TerminateJobFlows",
)

with (
patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
):
mock_get_waiter.return_value.wait.side_effect = waiter_error
mock_terminate.side_effect = cleanup_error

with pytest.raises(WaiterError) as exc:
self.operator.execute(self.mock_context)

# Original exception must be preserved
assert exc.value is waiter_error

# Cleanup attempted despite failure
mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])
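
Not part of this PR, but a sketch in the same style of how the opt-out path could be covered: with terminate_job_flow_on_failure=False the operator should re-raise without attempting termination. Fixture and constant names reuse the existing test setup.

def test_no_cleanup_when_termination_disabled(self, mocked_hook_client):
    """Hypothetical companion test: cleanup is skipped when the flag is off."""
    mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

    self.operator.wait_for_completion = True
    self.operator.terminate_job_flow_on_failure = False

    waiter_error = WaiterError("ClusterRunning", "simulated failure", {})

    with (
        patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
        patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
    ):
        mock_get_waiter.return_value.wait.side_effect = waiter_error

        with pytest.raises(WaiterError) as exc:
            self.operator.execute(self.mock_context)

        # Original exception still propagates unchanged
        assert exc.value is waiter_error

        # No termination attempt when the flag is disabled
        mock_terminate.assert_not_called()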