Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ def __init__(
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
event = validate_execute_complete_event(event)

if event["status"] == "failed":
# Command failed - raise an exception with detailed information
command_status = event.get("command_status", "Unknown")
exit_code = event.get("exit_code", -1)
instance_id = event.get("instance_id", "Unknown")
message = event.get("message", "Command failed")

error_msg = (
f"SSM run command {event['command_id']} failed on instance {instance_id}. "
f"Status: {command_status}, Exit code: {exit_code}. {message}"
)
raise RuntimeError(error_msg)

if event["status"] != "success":
raise RuntimeError(f"Error while running run command: {event}")

Expand Down
53 changes: 39 additions & 14 deletions providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
waiter_args={"CommandId": command_id},
failure_message="SSM run command failed.",
status_message="Status of SSM run command is",
status_queries=["status"],
status_queries=["Status"],
return_key="command_id",
return_value=command_id,
waiter_delay=waiter_delay,
Expand Down Expand Up @@ -105,19 +105,26 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.status_queries,
)
except Exception:
if not self.fail_on_nonzero_exit:
# Enhanced mode: check if it's an AWS-level failure
invocation = await client.get_command_invocation(
CommandId=self.command_id, InstanceId=instance_id
)
status = invocation.get("Status", "")
# Get detailed invocation information to determine failure type
invocation = await client.get_command_invocation(
CommandId=self.command_id, InstanceId=instance_id
)
status = invocation.get("Status", "")
response_code = invocation.get("ResponseCode", -1)

# AWS-level failures should always raise
if SsmHook.is_aws_level_failure(status):
raise
# AWS-level failures should always raise
if SsmHook.is_aws_level_failure(status):
self.log.error(
"AWS-level failure for command %s on instance %s: status=%s",
self.command_id,
instance_id,
status,
)
raise

# Command-level failure - tolerate it in enhanced mode
response_code = invocation.get("ResponseCode", "unknown")
# Command-level failure (non-zero exit code)
if not self.fail_on_nonzero_exit:
# Enhanced mode: tolerate command-level failures
self.log.info(
"Command %s completed with status %s (exit code: %s) for instance %s. "
"Continuing due to fail_on_nonzero_exit=False",
Expand All @@ -128,7 +135,25 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
)
continue
else:
# Traditional mode: all failures raise
raise
# Traditional mode: yield failure event instead of raising
# This allows the operator to handle the failure gracefully
self.log.warning(
"Command %s failed with status %s (exit code: %s) for instance %s",
self.command_id,
status,
response_code,
instance_id,
)
yield TriggerEvent(
{
"status": "failed",
"message": f"Command failed with status {status} (exit code: {response_code})",
"command_status": status,
"exit_code": response_code,
"instance_id": instance_id,
self.return_key: self.return_value,
}
)
return

yield TriggerEvent({"status": "success", self.return_key: self.return_value})
55 changes: 55 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,61 @@ def test_operator_passes_parameter_to_trigger(self, mock_trigger_class, mock_con
assert call_kwargs["command_id"] == COMMAND_ID
assert call_kwargs["fail_on_nonzero_exit"] is False

def test_execute_complete_success(self):
"""Test execute_complete with successful event."""
event = {"status": "success", "command_id": COMMAND_ID}

result = self.operator.execute_complete({}, event)

assert result == COMMAND_ID

def test_execute_complete_failure_event(self):
"""Test execute_complete with failure event from trigger."""
event = {
"status": "failed",
"command_id": COMMAND_ID,
"command_status": "Failed",
"exit_code": 1,
"instance_id": "i-123456",
"message": "Command failed with status Failed (exit code: 1)",
}

with pytest.raises(RuntimeError) as exc_info:
self.operator.execute_complete({}, event)

error_msg = str(exc_info.value)
assert COMMAND_ID in error_msg
assert "Failed" in error_msg
assert "exit code: 1" in error_msg
assert "i-123456" in error_msg

def test_execute_complete_failure_event_with_different_exit_codes(self):
"""Test execute_complete properly reports different exit codes in error messages."""
event = {
"status": "failed",
"command_id": COMMAND_ID,
"command_status": "Failed",
"exit_code": 42,
"instance_id": "i-789012",
"message": "Command failed with status Failed (exit code: 42)",
}

with pytest.raises(RuntimeError) as exc_info:
self.operator.execute_complete({}, event)

error_msg = str(exc_info.value)
assert "exit code: 42" in error_msg
assert "i-789012" in error_msg

def test_execute_complete_unknown_status(self):
"""Test execute_complete with unknown status."""
event = {"status": "unknown", "command_id": COMMAND_ID}

with pytest.raises(RuntimeError) as exc_info:
self.operator.execute_complete({}, event)

assert "Error while running run command" in str(exc_info.value)


class TestSsmGetCommandInvocationOperator:
@pytest.fixture
Expand Down
89 changes: 87 additions & 2 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,14 @@ async def test_run_success(self, mock_get_waiter, mock_get_async_conn, mock_ssm_
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_run_fails(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations):
mock_ssm_list_invocations(mock_get_async_conn)
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
mock_get_waiter().wait.side_effect = WaiterError(
"name", "terminal failure", {"CommandInvocations": [{"CommandId": COMMAND_ID}]}
)
# Mock get_command_invocation to return AWS-level failure
mock_client.get_command_invocation = mock.AsyncMock(
return_value={"Status": "TimedOut", "ResponseCode": -1}
)

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID)
generator = trigger.run()
Expand All @@ -124,8 +128,12 @@ async def test_trigger_default_fails_on_waiter_error(
self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations
):
"""Test traditional mode (fail_on_nonzero_exit=True) raises exception on waiter error."""
mock_ssm_list_invocations(mock_get_async_conn)
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
mock_async_wait.side_effect = AirflowException("SSM run command failed.")
# Mock get_command_invocation to return AWS-level failure
mock_client.get_command_invocation = mock.AsyncMock(
return_value={"Status": "Cancelled", "ResponseCode": -1}
)

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True)
generator = trigger.run()
Expand Down Expand Up @@ -204,3 +212,80 @@ def test_trigger_serialization_includes_parameter(self):
classpath, kwargs = trigger_default.serialize()

assert kwargs.get("fail_on_nonzero_exit") is True

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait")
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_trigger_yields_failure_event_instead_of_raising(
self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations
):
"""Test that trigger yields failure event instead of raising exception for command failures."""
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
# Mock async_wait to raise exception (simulating waiter failure)
mock_async_wait.side_effect = AirflowException("SSM run command failed.")
# Mock get_command_invocation to return Failed status with exit code 1
mock_client.get_command_invocation = mock.AsyncMock(
return_value={"Status": "Failed", "ResponseCode": 1}
)

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True)
generator = trigger.run()
response = await generator.asend(None)

# Should yield a failure event, not raise an exception
assert response.payload["status"] == "failed"
assert response.payload["command_id"] == COMMAND_ID
assert response.payload["exit_code"] == 1
assert response.payload["command_status"] == "Failed"
assert response.payload["instance_id"] == INSTANCE_ID_1
assert "Command failed with status Failed (exit code: 1)" in response.payload["message"]

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait")
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_trigger_yields_failure_event_for_different_exit_codes(
self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations
):
"""Test that trigger properly captures different exit codes in failure events."""
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
mock_async_wait.side_effect = AirflowException("SSM run command failed.")

# Test with exit code 2
mock_client.get_command_invocation = mock.AsyncMock(
return_value={"Status": "Failed", "ResponseCode": 2}
)

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True)
generator = trigger.run()
response = await generator.asend(None)

assert response.payload["status"] == "failed"
assert response.payload["exit_code"] == 2
assert response.payload["command_status"] == "Failed"

@pytest.mark.asyncio
@mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait")
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_trigger_continues_on_second_instance_after_first_fails(
self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations
):
"""Test that trigger stops after first failure and yields failure event."""
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
# First instance fails
mock_async_wait.side_effect = AirflowException("SSM run command failed.")
mock_client.get_command_invocation = mock.AsyncMock(
return_value={"Status": "Failed", "ResponseCode": 1}
)

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True)
generator = trigger.run()
response = await generator.asend(None)

# Should yield failure event for first instance
assert response.payload["status"] == "failed"
assert response.payload["instance_id"] == INSTANCE_ID_1
# Should only call get_command_invocation once (for first instance)
assert mock_client.get_command_invocation.call_count == 1