132 changes: 82 additions & 50 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py
@@ -24,6 +24,8 @@
 from time import sleep
 from typing import TYPE_CHECKING, Any
 
+from botocore.exceptions import WaiterError
+
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
@@ -394,6 +396,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
         (default: False)
     :param do_xcom_push: If True, the operator will push the ECS task ARN to XCom with key 'ecs_task_arn'.
         Additionally, if logs are fetched, the last log message will be pushed to XCom with the key 'return_value'. (default: False)
+    :param stop_task_on_failure: If True, attempt to stop the ECS task if the Airflow task fails
+        after the ECS task has started. (default: True)
     """
 
     ui_color = "#f0ede4"
@@ -457,6 +461,7 @@ def __init__(
         # Set the default waiter duration to 70 days (attempts*delay)
         # Airflow execution_timeout handles task timeout
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        stop_task_on_failure: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -495,6 +500,7 @@ def __init__(
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
+        self.stop_task_on_failure = stop_task_on_failure
 
         if self._aws_logs_enabled() and not self.wait_for_completion:
             self.log.warning(
@@ -513,60 +519,86 @@ def execute(self, context):
         )
         self.log.info("EcsOperator overrides: %s", self.overrides)
 
-        if self.reattach:
-            # Generate deterministic UUID which refers to unique TaskInstanceKey
-            ti: TaskInstance = context["ti"]
-            self._started_by = generate_uuid(*map(str, [ti.dag_id, ti.task_id, ti.run_id, ti.map_index]))
-            self.log.info("Try to find run with startedBy=%r", self._started_by)
-            self._try_reattach_task(started_by=self._started_by)
-
-        if not self.arn:
-            # start the task except if we reattached to an existing one just before.
-            self._start_task()
-
-        if self.do_xcom_push:
-            context["ti"].xcom_push(key="ecs_task_arn", value=self.arn)
-
-        if self.deferrable:
-            self.defer(
-                trigger=TaskDoneTrigger(
-                    cluster=self.cluster,
-                    task_arn=self.arn,
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                    region=self.region_name,
-                    log_group=self.awslogs_group,
-                    log_stream=self._get_logs_stream_name(),
-                ),
-                method_name="execute_complete",
-                # timeout is set to ensure that if a trigger dies, the timeout does not restart
-                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
-                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
-            )
-            # self.defer raises a special exception, so execution stops here in this case.
-
-        if not self.wait_for_completion:
-            return
-
-        if self._aws_logs_enabled():
-            self.log.info("Starting ECS Task Log Fetcher")
-            self.task_log_fetcher = self._get_task_log_fetcher()
-            self.task_log_fetcher.start()
-
-            try:
-                self._wait_for_task_ended()
-            finally:
-                self.task_log_fetcher.stop()
-                self.task_log_fetcher.join()
-        else:
-            self._wait_for_task_ended()
-
-        self._after_execution()
-
-        if self.do_xcom_push and self.task_log_fetcher:
-            return self.task_log_fetcher.get_last_log_message()
-        return None
+        task_started_by_this_run = False
+
+        try:
+            if self.reattach:
+                # Generate deterministic UUID which refers to unique TaskInstanceKey
+                ti: TaskInstance = context["ti"]
+                self._started_by = generate_uuid(*map(str, [ti.dag_id, ti.task_id, ti.run_id, ti.map_index]))
+                self.log.info("Try to find run with startedBy=%r", self._started_by)
+                self._try_reattach_task(started_by=self._started_by)
+
+            if not self.arn:
+                # start the task except if we reattached to an existing one just before.
+                self._start_task()
+                task_started_by_this_run = True
+
+            if self.do_xcom_push:
+                context["ti"].xcom_push(key="ecs_task_arn", value=self.arn)
+
+            if self.deferrable:
+                self.defer(
+                    trigger=TaskDoneTrigger(
+                        cluster=self.cluster,
+                        task_arn=self.arn,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                        region=self.region_name,
+                        log_group=self.awslogs_group,
+                        log_stream=self._get_logs_stream_name(),
+                    ),
+                    method_name="execute_complete",
+                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+                )
+                # self.defer raises a special exception, so execution stops here in this case.
+
+            if not self.wait_for_completion:
+                return
+
+            if self._aws_logs_enabled():
+                self.log.info("Starting ECS Task Log Fetcher")
+                self.task_log_fetcher = self._get_task_log_fetcher()
+                self.task_log_fetcher.start()
+
+                try:
+                    self._wait_for_task_ended()
+                finally:
+                    self.task_log_fetcher.stop()
+                    self.task_log_fetcher.join()
+            else:
+                self._wait_for_task_ended()
+
+            self._after_execution()
+
+            if self.do_xcom_push and self.task_log_fetcher:
+                return self.task_log_fetcher.get_last_log_message()
+            return None
+        except WaiterError:
+            # Best-effort cleanup when post-initiation steps fail (e.g. IAM/permission errors).
+            if task_started_by_this_run and self.arn:
+                self.log.warning(
+                    "Execution failed after ECS task %s was started by this task instance.", self.arn
+                )
+
+                if self.stop_task_on_failure:
+                    try:
+                        self.log.warning("Attempting termination of ECS task %s.", self.arn)
+
+                        self.client.stop_task(
+                            cluster=self.cluster,
+                            task=self.arn,
+                            reason="Task failed after creation; cleanup by Airflow",
+                        )
+                    except Exception:
+                        self.log.exception(
+                            "Failed while attempting to stop ECS task %s",
+                            self.arn,
+                        )
+            raise
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str | None:
         validated_event = validate_execute_complete_event(event)
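For reviewers, a minimal usage sketch of the new flag. This is a hypothetical DAG task, not code from this PR; the cluster, task definition, and overrides values are illustrative.

from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

run_task = EcsRunTaskOperator(
    task_id="run_ecs_task",
    cluster="my-cluster",  # illustrative value
    task_definition="my-task-def",  # illustrative value
    launch_type="FARGATE",
    overrides={},
    wait_for_completion=True,
    # New in this PR: best-effort StopTask if execute() fails after RunTask succeeded.
    stop_task_on_failure=True,
)
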
92 changes: 92 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
@@ -23,6 +23,7 @@
 
 import boto3
 import pytest
+from botocore.exceptions import ClientError, WaiterError
 
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
@@ -852,6 +853,97 @@ def test_container_name_not_polled(self, client_mock):
         self.ecs._start_task()
         assert client_mock.describe_tasks.call_count == 0
 
+    @mock.patch.object(EcsBaseOperator, "client")
+    @mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
+    def test_cleanup_on_post_start_failure(self, wait_mock, client_mock):
+        """
+        Ensure that if an ECS task is started successfully but a subsequent
+        post-start step fails (e.g. DescribeTasks permission denied),
+        the operator attempts best-effort cleanup.
+        """
+        self.set_up_operator(
+            launch_type="FARGATE",
+            capacity_provider_strategy=None,
+            platform_version=None,
+            tags=None,
+            volume_configurations=None,
+            stop_task_on_failure=True,
+        )
+
+        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+        waiter_error = WaiterError(
+            "AccessDeniedException",
+            "Not authorized to perform ecs:DescribeTasks",
+            {},
+        )
+
+        wait_mock.side_effect = waiter_error
+
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        with pytest.raises(WaiterError) as exc:
+            self.ecs.execute(mock_context)
+
+        # Original exception must propagate unchanged.
+        assert exc.value is waiter_error
+
+        # Cleanup must be attempted.
+        client_mock.stop_task.assert_called_once_with(
+            cluster="c",
+            task=f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+            reason=mock.ANY,
+        )
+
+    @mock.patch.object(EcsBaseOperator, "client")
+    @mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
+    def test_cleanup_failure_does_not_mask_original_exception(self, wait_mock, client_mock):
+        """
+        Ensure that failure during ECS cleanup does not override
+        the original post-start exception.
+        """
+        self.set_up_operator(
+            launch_type="FARGATE",
+            capacity_provider_strategy=None,
+            platform_version=None,
+            tags=None,
+            volume_configurations=None,
+            stop_task_on_failure=True,
+        )
+
+        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+        waiter_error = WaiterError(
+            "AccessDeniedException",
+            "Not authorized to perform ecs:DescribeTasks",
+            {},
+        )
+        wait_mock.side_effect = waiter_error
+
+        cleanup_error = ClientError(
+            error_response={
+                "Error": {
+                    "Code": "AccessDeniedException",
+                    "Message": "Not authorized to perform ecs:StopTask",
+                }
+            },
+            operation_name="StopTask",
+        )
+        client_mock.stop_task.side_effect = cleanup_error
+
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti, "task_instance": mock_ti}
+
+        with pytest.raises(WaiterError) as exc:
+            self.ecs.execute(mock_context)
+
+        # Original exception must be preserved.
+        assert exc.value is waiter_error
+
+        # Cleanup attempted despite failure.
+        client_mock.stop_task.assert_called_once()
+
+
 class TestEcsCreateClusterOperator(EcsBaseTestCase):
     @pytest.mark.parametrize(("waiter_delay", "waiter_max_attempts"), WAITERS_TEST_CASES)
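The exception-handling contract pinned down by these tests can be read in isolation. A minimal sketch, assuming hypothetical start_task/wait_for_task/stop_task helpers standing in for the RunTask, waiter, and StopTask calls (not code from this PR):

from botocore.exceptions import WaiterError


def run_with_cleanup(start_task, wait_for_task, stop_task):
    # Start, then wait; on waiter failure, clean up without masking the error.
    arn = start_task()
    try:
        wait_for_task(arn)
    except WaiterError:
        try:
            stop_task(arn)  # best-effort cleanup
        except Exception:
            pass  # the operator logs this instead of re-raising
        raise  # the original WaiterError propagates, as both tests assert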