Skip to content

Add Model Invocation Hooks #387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast

from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ..experimental.hooks.events import MessageAddedEvent
from ..experimental.hooks.registry import get_registry
from ..experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
MessageAddedEvent,
get_registry,
)
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
Expand Down Expand Up @@ -115,6 +120,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener

tool_specs = agent.tool_registry.get_all_tool_specs()

get_registry(agent).invoke_callbacks(
BeforeModelInvocationEvent(
agent=agent,
)
)

try:
# TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding
# to the callback handler. This will be revisited when migrating to strongly typed events.
Expand All @@ -125,40 +136,50 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
stop_reason, message, usage, metrics = event["stop"]
kwargs.setdefault("request_state", {})

get_registry(agent).invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
stop_reason=stop_reason,
message=message,
),
)
)

if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
break # Success! Break out of retry loop

except ContextWindowOverflowException as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise e

except ModelThrottledException as e:
except Exception as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)

if attempt + 1 == MAX_ATTEMPTS:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
raise e

logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
get_registry(agent).invoke_callbacks(
AfterModelInvocationEvent(
agent=agent,
exception=e,
)
)
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}}
if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
raise e

except Exception as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)
raise e
logger.debug(
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
attempt + 1,
)
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}}
else:
raise e

try:
# Add message in trace and mark the end of the stream messages trace
Expand Down
4 changes: 4 additions & 0 deletions src/strands/experimental/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ def log_end(self, event: EndRequestEvent) -> None:
"""

from .events import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
AgentInitializedEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
MessageAddedEvent,
Expand All @@ -43,6 +45,8 @@ def log_end(self, event: EndRequestEvent) -> None:
"AgentInitializedEvent",
"StartRequestEvent",
"EndRequestEvent",
"BeforeModelInvocationEvent",
"AfterModelInvocationEvent",
"BeforeToolInvocationEvent",
"AfterToolInvocationEvent",
"MessageAddedEvent",
Expand Down
54 changes: 54 additions & 0 deletions src/strands/experimental/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Optional

from ...types.content import Message
from ...types.streaming import StopReason
from ...types.tools import AgentTool, ToolResult, ToolUse
from .registry import HookEvent

Expand Down Expand Up @@ -121,6 +122,59 @@ def should_reverse_callbacks(self) -> bool:
return True


@dataclass
class BeforeModelInvocationEvent(HookEvent):
"""Event triggered before the model is invoked.

This event is fired just before the agent calls the model for inference,
allowing hook providers to inspect or modify the messages and configuration
that will be sent to the model.

Note: This event is not fired for invocations to structured_output.
"""

pass


@dataclass
class AfterModelInvocationEvent(HookEvent):
"""Event triggered after the model invocation completes.

This event is fired after the agent has finished calling the model,
regardless of whether the invocation was successful or resulted in an error.
Hook providers can use this event for cleanup, logging, or post-processing.

Note: This event uses reverse callback ordering, meaning callbacks registered
later will be invoked first during cleanup.

Note: This event is not fired for invocations to structured_output.

Attributes:
stop_response: The model response data if invocation was successful, None if failed.
exception: Exception if the model invocation failed, None if successful.
"""

@dataclass
class ModelStopResponse:
"""Model response data from successful invocation.

Attributes:
stop_reason: The reason the model stopped generating.
message: The generated message from the model.
"""

message: Message
stop_reason: StopReason

stop_response: Optional[ModelStopResponse] = None
exception: Optional[Exception] = None

@property
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True


@dataclass
class MessageAddedEvent(HookEvent):
"""Event triggered when a message is added to the agent's conversation.
Expand Down
20 changes: 20 additions & 0 deletions src/strands/experimental/hooks/rules.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Hook System Rules

## Terminology

- **Paired events**: Events that denote the beginning and end of an operation
- **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response

## Naming Conventions

- All hook events have a suffix of `Event`
- Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event`

## Paired Events

- The final event in a pair returns `True` for `should_reverse_callbacks`
- For every `Before` event there is a corresponding `After` event, even if an exception occurs

## Writable Properties

For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable - after invoking the callback for `BeforeToolInvocationEvent`, the `selected_tool` takes effect for the tool call.
63 changes: 59 additions & 4 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import strands
from strands import Agent
from strands.experimental.hooks import (
AfterModelInvocationEvent,
AfterToolInvocationEvent,
AgentInitializedEvent,
BeforeModelInvocationEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
MessageAddedEvent,
Expand All @@ -29,6 +31,8 @@ def hook_provider():
EndRequestEvent,
AfterToolInvocationEvent,
BeforeToolInvocationEvent,
BeforeModelInvocationEvent,
AfterModelInvocationEvent,
MessageAddedEvent,
]
)
Expand Down Expand Up @@ -84,6 +88,11 @@ def assert_message_is_last_message_added(event: MessageAddedEvent):
return agent


@pytest.fixture
def tools_config(agent):
return agent.tool_config["tools"]


@pytest.fixture
def user():
class User(BaseModel):
Expand Down Expand Up @@ -131,20 +140,33 @@ def test_agent_tool_call(agent, hook_provider, agent_tool):
assert len(agent.messages) == 4


def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use):
"""Verify that the correct hook events are emitted as part of __call__."""

agent("test message")

length, events = hook_provider.get_events()

assert length == 8
assert length == 12

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == BeforeModelInvocationEvent(agent=agent)
assert next(events) == AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
message={
"content": [{"toolUse": tool_use}],
"role": "assistant",
},
stop_reason="tool_use",
),
exception=None,
)

assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
assert next(events) == BeforeToolInvocationEvent(
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
Expand All @@ -157,14 +179,24 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == BeforeModelInvocationEvent(agent=agent)
assert next(events) == AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
message=mock_model.agent_responses[1],
stop_reason="end_turn",
),
exception=None,
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])

assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 4


@pytest.mark.asyncio
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_use):
async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_model, tool_use, agenerator):
"""Verify that the correct hook events are emitted as part of stream_async."""
iterator = agent.stream_async("test message")
await anext(iterator)
Expand All @@ -176,13 +208,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u

length, events = hook_provider.get_events()

assert length == 8
assert length == 12

assert next(events) == StartRequestEvent(agent=agent)
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
)
assert next(events) == BeforeModelInvocationEvent(agent=agent)
assert next(events) == AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
message={
"content": [{"toolUse": tool_use}],
"role": "assistant",
},
stop_reason="tool_use",
),
exception=None,
)

assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1])
assert next(events) == BeforeToolInvocationEvent(
agent=agent, selected_tool=agent_tool, tool_use=tool_use, kwargs=ANY
Expand All @@ -195,7 +240,17 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
assert next(events) == BeforeModelInvocationEvent(agent=agent)
assert next(events) == AfterModelInvocationEvent(
agent=agent,
stop_response=AfterModelInvocationEvent.ModelStopResponse(
message=mock_model.agent_responses[1],
stop_reason="end_turn",
),
exception=None,
)
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3])

assert next(events) == EndRequestEvent(agent=agent)

assert len(agent.messages) == 4
Expand Down
Loading