Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,31 +218,46 @@ def execute(self, context: Context):
MaxCount=self.max_count,
**self.config,
)["Instances"]

instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances]
# Console link is for EC2 dashboard list, not individual instances when more than 1 instance

EC2InstanceDashboardLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids),
)
for instance_id in instance_ids:
self.log.info("Created EC2 instance %s", instance_id)

if self.wait_for_completion:
self.hook.get_waiter("instance_running").wait(
InstanceIds=[instance_id],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
try:
instance_ids = self._on_kill_instance_ids = [instance["InstanceId"] for instance in instances]
# Console link is for EC2 dashboard list, not individual instances when more than 1 instance

EC2InstanceDashboardLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids),
)
for instance_id in instance_ids:
self.log.info("Created EC2 instance %s", instance_id)

if self.wait_for_completion:
self.hook.get_waiter("instance_running").wait(
InstanceIds=[instance_id],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)

# leave "_on_kill_instance_ids" in place for finishing post-processing
return instance_ids

# Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors).
except Exception:
self.log.exception(
"Exception after EC2 instance creation; attempting cleanup for instances %s",
instance_ids,
)
try:
self.hook.terminate_instances(instance_ids=instance_ids)
except Exception:
self.log.exception(
"Failed to cleanup EC2 instances %s after task failure",
instance_ids,
)

# leave "_on_kill_instance_ids" in place for finishing post-processing
return instance_ids
raise

def on_kill(self) -> None:
instance_ids = getattr(self, "_on_kill_instance_ids", [])
Expand Down
77 changes: 77 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest
from botocore.exceptions import ClientError, WaiterError
from moto import mock_aws

from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
Expand Down Expand Up @@ -96,6 +99,80 @@ def test_template_fields(self):
)
validate_template_fields(ec2_operator)

@mock_aws
def test_cleanup_on_post_create_failure(self):
ec2_hook = EC2Hook()

operator = EC2CreateInstanceOperator(
task_id="test_cleanup_on_error",
image_id=self._get_image_id(ec2_hook),
wait_for_completion=True,
)

waiter_error = WaiterError(
"InstanceRunning",
"You are not authorized to perform this operation",
{},
)

# Force failure after instance creation (e.g. missing DescribeInstances permission).
with mock.patch.object(operator.hook, "get_waiter") as mock_get_waiter:
mock_get_waiter.return_value.wait.side_effect = waiter_error
with pytest.raises(WaiterError) as exc:
operator.execute(None)

# Ensure the original waiter exception is propagated unchanged.
assert exc.value is waiter_error

# Instance must have been terminated.
# We know exactly one instance was created.
instances = list(ec2_hook.conn.instances.all())
assert len(instances) == 1

instance = instances[0]
assert instance.state["Name"] == "terminated"

@mock_aws
def test_cleanup_failure_propagates_original_exception(self):
ec2_hook = EC2Hook()

operator = EC2CreateInstanceOperator(
task_id="test_cleanup_failure_does_not_mask_error",
image_id=self._get_image_id(ec2_hook),
wait_for_completion=True,
)

waiter_error = WaiterError(
"InstanceRunning",
"You are not authorized to perform this operation",
{},
)

cleanup_error = ClientError(
error_response={
"Error": {
"Code": "UnauthorizedOperation",
"Message": "You are not authorized to perform this operation",
}
},
operation_name="TerminateInstances",
)

with (
mock.patch.object(operator.hook, "get_waiter") as mock_get_waiter,
mock.patch.object(operator.hook, "terminate_instances") as mock_terminate,
):
mock_get_waiter.return_value.wait.side_effect = waiter_error
mock_terminate.side_effect = cleanup_error

with pytest.raises(WaiterError) as exc:
operator.execute(None)

# Ensure the original waiter exception is propagated unchanged.
assert exc.value is waiter_error

# Cleanup is best-effort; failure to terminate must not override the original error.


class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down