Skip to content

Commit 9f70298

Browse files
authored
feat(hooks): add AgentResult to AfterInvocationEvent (#1125)
The AfterInvocationEvent hook event did not provide access to the AgentResult, making it difficult for hooks to perform actions based on the outcome of an agent's invocation. This PR updates the AfterInvocationEvent to include an optional AgentResult. The Agent.invoke method now captures the AgentResult and passes it to the AfterInvocationEvent. New tests have been added to verify the functionality and immutability of the event.
1 parent e692133 commit 9f70298

File tree

4 files changed

+46
-8
lines changed

4 files changed

+46
-8
lines changed

src/strands/agent/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from ..tools.registry import ToolRegistry
5757
from ..tools.structured_output._structured_output_context import StructuredOutputContext
5858
from ..tools.watcher import ToolWatcher
59-
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
59+
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
6060
from ..types.agent import AgentInput
6161
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
6262
from ..types.exceptions import ContextWindowOverflowException
@@ -621,6 +621,7 @@ async def _run_loop(
621621
"""
622622
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
623623

624+
agent_result: AgentResult | None = None
624625
try:
625626
yield InitEventLoopEvent()
626627

@@ -648,9 +649,13 @@ async def _run_loop(
648649
self._session_manager.redact_latest_message(self.messages[-1], self)
649650
yield event
650651

652+
# Capture the result from the final event if available
653+
if isinstance(event, EventLoopStopEvent):
654+
agent_result = AgentResult(*event["stop"])
655+
651656
finally:
652657
self.conversation_manager.apply_management(self)
653-
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))
658+
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result))
654659

655660
async def _execute_event_loop_cycle(
656661
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None

src/strands/hooks/events.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
import uuid
77
from dataclasses import dataclass
8-
from typing import Any, Optional
8+
from typing import TYPE_CHECKING, Any, Optional
99

1010
from typing_extensions import override
1111

12+
if TYPE_CHECKING:
13+
from ..agent.agent_result import AgentResult
14+
1215
from ..types.content import Message
1316
from ..types.interrupt import _Interruptible
1417
from ..types.streaming import StopReason
@@ -60,8 +63,15 @@ class AfterInvocationEvent(HookEvent):
6063
- Agent.__call__
6164
- Agent.stream_async
6265
- Agent.structured_output
66+
67+
Attributes:
68+
result: The result of the agent invocation, if available.
69+
This will be None when invoked from structured_output methods, as those return typed output directly rather
70+
than AgentResult.
6371
"""
6472

73+
result: "AgentResult | None" = None
74+
6575
@property
6676
def should_reverse_callbacks(self) -> bool:
6777
"""True to invoke callbacks in reverse order."""

tests/strands/agent/hooks/test_events.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from strands.agent.agent_result import AgentResult
56
from strands.hooks import (
67
AfterInvocationEvent,
78
AfterToolCallEvent,
@@ -10,6 +11,7 @@
1011
BeforeToolCallEvent,
1112
MessageAddedEvent,
1213
)
14+
from strands.types.content import Message
1315
from strands.types.tools import ToolResult, ToolUse
1416

1517

@@ -138,3 +140,22 @@ def test_after_tool_invocation_event_cannot_write_properties(after_tool_event):
138140
after_tool_event.invocation_state = {}
139141
with pytest.raises(AttributeError, match="Property exception is not writable"):
140142
after_tool_event.exception = Exception("test")
143+
144+
145+
def test_after_invocation_event_properties_not_writable(agent):
146+
"""Test that properties are not writable after initialization."""
147+
mock_message: Message = {"role": "assistant", "content": [{"text": "test"}]}
148+
mock_result = AgentResult(
149+
stop_reason="end_turn",
150+
message=mock_message,
151+
metrics={},
152+
state={},
153+
)
154+
155+
event = AfterInvocationEvent(agent=agent, result=None)
156+
157+
with pytest.raises(AttributeError, match="Property result is not writable"):
158+
event.result = mock_result
159+
160+
with pytest.raises(AttributeError, match="Property agent is not writable"):
161+
event.agent = Mock()

tests/strands/agent/test_agent_hooks.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_agent_tool_call(agent, hook_provider, agent_tool):
147147
def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use):
148148
"""Verify that the correct hook events are emitted as part of __call__."""
149149

150-
agent("test message")
150+
result = agent("test message")
151151

152152
length, events = hook_provider.get_events()
153153

@@ -197,7 +197,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
197197
)
198198
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
199199

200-
assert next(events) == AfterInvocationEvent(agent=agent)
200+
assert next(events) == AfterInvocationEvent(agent=agent, result=result)
201201

202202
assert len(agent.messages) == 4
203203

@@ -210,8 +210,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
210210
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)]
211211

212212
# iterate the rest
213-
async for _ in iterator:
214-
pass
213+
result = None
214+
async for item in iterator:
215+
if "result" in item:
216+
result = item["result"]
215217

216218
length, events = hook_provider.get_events()
217219

@@ -261,7 +263,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
261263
)
262264
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])
263265

264-
assert next(events) == AfterInvocationEvent(agent=agent)
266+
assert next(events) == AfterInvocationEvent(agent=agent, result=result)
265267

266268
assert len(agent.messages) == 4
267269

0 commit comments

Comments
 (0)