Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def hook(self) -> AwsGenericHook:

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.hook()
async with hook.async_conn as client:
response = client.list_command_invocations(CommandId=self.command_id)
async with await hook.get_async_conn() as client:
response = await client.list_command_invocations(CommandId=self.command_id)
instance_ids = [invocation["InstanceId"] for invocation in response.get("CommandInvocations", [])]
waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)

Expand Down
2 changes: 1 addition & 1 deletion providers/amazon/tests/system/amazon/aws/example_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def wait_until_ssm_ready(instance_id: str, max_attempts: int = 10, delay_seconds

# [START howto_sensor_run_command]
await_run_command = SsmRunCommandCompletedSensor(
task_id="await_run_command", command_id=run_command.output
task_id="await_run_command", command_id="{{ ti.xcom_pull(task_ids='run_command') }}"
)
# [END howto_sensor_run_command]

Expand Down
30 changes: 16 additions & 14 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@

@pytest.fixture
def mock_ssm_list_invocations():
def _setup(mock_async_conn):
def _setup(mock_get_async_conn):
mock_client = mock.MagicMock()
mock_async_conn.__aenter__.return_value = mock_client
mock_client.list_command_invocations.return_value = {
"CommandInvocations": [
{"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
{"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
]
}
mock_get_async_conn.return_value.__aenter__.return_value = mock_client
mock_client.list_command_invocations = mock.AsyncMock(
return_value={
"CommandInvocations": [
{"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_1},
{"CommandId": COMMAND_ID, "InstanceId": INSTANCE_ID_2},
]
}
)
return mock_client

return _setup
Expand All @@ -60,10 +62,10 @@ def test_serialization(self):
assert kwargs.get("command_id") == COMMAND_ID

@pytest.mark.asyncio
@mock.patch.object(SsmHook, "async_conn")
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_run_success(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations):
mock_client = mock_ssm_list_invocations(mock_async_conn)
async def test_run_success(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations):
mock_client = mock_ssm_list_invocations(mock_get_async_conn)
mock_get_waiter().wait = mock.AsyncMock(name="wait")

trigger = SsmRunCommandTrigger(command_id=COMMAND_ID)
Expand All @@ -82,10 +84,10 @@ async def test_run_success(self, mock_get_waiter, mock_async_conn, mock_ssm_list
mock_client.list_command_invocations.assert_called_once_with(CommandId=COMMAND_ID)

@pytest.mark.asyncio
@mock.patch.object(SsmHook, "async_conn")
@mock.patch.object(SsmHook, "get_async_conn")
@mock.patch.object(SsmHook, "get_waiter")
async def test_run_fails(self, mock_get_waiter, mock_async_conn, mock_ssm_list_invocations):
mock_ssm_list_invocations(mock_async_conn)
async def test_run_fails(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations):
mock_ssm_list_invocations(mock_get_async_conn)
mock_get_waiter().wait.side_effect = WaiterError(
"name", "terminal failure", {"CommandInvocations": [{"CommandId": COMMAND_ID}]}
)
Expand Down