Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Private async execution utilities."""

import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor
from typing import Awaitable, Callable, TypeVar

Expand All @@ -27,5 +28,6 @@ def execute() -> T:
return asyncio.run(execute_async())

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
context = contextvars.copy_context()
future = executor.submit(context.run, execute)
return future.result()
105 changes: 51 additions & 54 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
HookRegistry,
MessageAddedEvent,
)
from ..interrupt import _InterruptState
from ..models.bedrock import BedrockModel
from ..models.model import Model
from ..session.session_manager import SessionManager
Expand All @@ -60,15 +61,13 @@
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException
from ..types.interrupt import InterruptResponseContent
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
ConversationManager,
SlidingWindowConversationManager,
)
from .interrupt import InterruptState
from .state import AgentState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -179,8 +178,8 @@ def __init__(
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []
# initializing self.system_prompt for backwards compatibility
self.system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt)
# initializing self._system_prompt for backwards compatibility
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt)
self._default_structured_output_model = structured_output_model
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
Expand Down Expand Up @@ -243,7 +242,7 @@ def __init__(

self.hooks = HookRegistry()

self._interrupt_state = InterruptState()
self._interrupt_state = _InterruptState()

# Initialize session management functionality
self._session_manager = session_manager
Expand All @@ -257,6 +256,35 @@ def __init__(
self.hooks.add_hook(hook)
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))

@property
def system_prompt(self) -> str | None:
"""Get the system prompt as a string for backwards compatibility.

Returns the system prompt as a concatenated string when it contains text content,
or None if no text content is present. This maintains backwards compatibility
with existing code that expects system_prompt to be a string.

Returns:
The system prompt as a string, or None if no text content exists.
"""
return self._system_prompt

@system_prompt.setter
def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None:
"""Set the system prompt and update internal content representation.

Accepts either a string or list of SystemContentBlock objects.
When set, both the backwards-compatible string representation and the internal
content block representation are updated to maintain consistency.

Args:
value: System prompt as string, list of SystemContentBlock objects, or None.
- str: Simple text prompt (most common use case)
- list[SystemContentBlock]: Content blocks with features like caching
- None: Clear the system prompt
"""
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value)

@property
def tool(self) -> _ToolCaller:
"""Call tool as a function.
Expand Down Expand Up @@ -424,15 +452,15 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
category=DeprecationWarning,
stacklevel=2,
)
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
with self.tracer.tracer.start_as_current_span(
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
) as structured_output_span:
try:
if not self.messages and not prompt:
raise ValueError("No conversation history or prompt provided")

temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt)

structured_output_span.set_attributes(
{
Expand Down Expand Up @@ -465,7 +493,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
return event["output"]

finally:
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))

def cleanup(self) -> None:
"""Clean up resources used by the agent.
Expand Down Expand Up @@ -531,7 +559,7 @@ async def stream_async(
yield event["data"]
```
"""
self._resume_interrupt(prompt)
self._interrupt_state.resume(prompt)

merged_state = {}
if kwargs:
Expand All @@ -548,7 +576,7 @@ async def stream_async(
callback_handler = kwargs.get("callback_handler", self.callback_handler)

# Process input and get message to add (if any)
messages = self._convert_prompt_to_messages(prompt)
messages = await self._convert_prompt_to_messages(prompt)

self.trace_span = self._start_agent_trace_span(messages)

Expand All @@ -574,38 +602,6 @@ async def stream_async(
self._end_agent_trace_span(error=e)
raise

def _resume_interrupt(self, prompt: AgentInput) -> None:
"""Configure the interrupt state if resuming from an interrupt event.

Args:
prompt: User responses if resuming from interrupt.

Raises:
TypeError: If in interrupt state but user did not provide responses.
"""
if not self._interrupt_state.activated:
return

if not isinstance(prompt, list):
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")

invalid_types = [
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
]
if invalid_types:
raise TypeError(
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
)

for content in cast(list[InterruptResponseContent], prompt):
interrupt_id = content["interruptResponse"]["interruptId"]
interrupt_response = content["interruptResponse"]["response"]

if interrupt_id not in self._interrupt_state.interrupts:
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")

self._interrupt_state.interrupts[interrupt_id].response = interrupt_response

async def _run_loop(
self,
messages: Messages,
Expand All @@ -622,13 +618,13 @@ async def _run_loop(
Yields:
Events from the event loop cycle.
"""
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))

try:
yield InitEventLoopEvent()

for message in messages:
self._append_message(message)
await self._append_message(message)

structured_output_context = StructuredOutputContext(
structured_output_model or self._default_structured_output_model
Expand All @@ -654,7 +650,7 @@ async def _run_loop(

finally:
self.conversation_manager.apply_management(self)
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self))

async def _execute_event_loop_cycle(
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
Expand Down Expand Up @@ -703,7 +699,7 @@ async def _execute_event_loop_cycle(
if structured_output_context:
structured_output_context.cleanup(self.tool_registry)

def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if self._interrupt_state.activated:
return []

Expand All @@ -718,7 +714,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
tool_use_ids = [
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
]
self._append_message(
await self._append_message(
{
"role": "user",
"content": generate_missing_tool_result_content(tool_use_ids),
Expand Down Expand Up @@ -749,7 +745,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
return messages

def _record_tool_execution(
async def _record_tool_execution(
self,
tool: ToolUse,
tool_result: ToolResult,
Expand Down Expand Up @@ -809,10 +805,10 @@ def _record_tool_execution(
}

# Add to message history
self._append_message(user_msg)
self._append_message(tool_use_msg)
self._append_message(tool_result_msg)
self._append_message(assistant_msg)
await self._append_message(user_msg)
await self._append_message(tool_use_msg)
await self._append_message(tool_result_msg)
await self._append_message(assistant_msg)

def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
"""Starts a trace span for the agent.
Expand All @@ -828,6 +824,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
tools=self.tool_names,
system_prompt=self.system_prompt,
custom_trace_attributes=self.trace_attributes,
tools_config=self.tool_registry.get_all_tools_config(),
)

def _end_agent_trace_span(
Expand Down Expand Up @@ -897,10 +894,10 @@ def _initialize_system_prompt(
else:
return None, None

def _append_message(self, message: Message) -> None:
async def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))

def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
"""Redact user content preserving toolResult blocks.
Expand Down
59 changes: 0 additions & 59 deletions src/strands/agent/interrupt.py

This file was deleted.

12 changes: 6 additions & 6 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def event_loop_cycle(
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
agent._append_message(
await agent._append_message(
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
)

Expand Down Expand Up @@ -322,7 +322,7 @@ async def _handle_model_execution(
model_id=model_id,
)
with trace_api.use_span(model_invoke_span):
agent.hooks.invoke_callbacks(
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
)
Expand All @@ -347,7 +347,7 @@ async def _handle_model_execution(
stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})

agent.hooks.invoke_callbacks(
await agent.hooks.invoke_callbacks_async(
AfterModelCallEvent(
agent=agent,
stop_response=AfterModelCallEvent.ModelStopResponse(
Expand All @@ -368,7 +368,7 @@ async def _handle_model_execution(
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)

agent.hooks.invoke_callbacks(
await agent.hooks.invoke_callbacks_async(
AfterModelCallEvent(
agent=agent,
exception=e,
Expand Down Expand Up @@ -402,7 +402,7 @@ async def _handle_model_execution(

# Add the response message to the conversation
agent.messages.append(message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message))

# Update metrics
agent.event_loop_metrics.update_usage(usage)
Expand Down Expand Up @@ -507,7 +507,7 @@ async def _handle_tool_execution(
}

agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message))

yield ToolResultMessageEvent(message=tool_result_message)

Expand Down
7 changes: 5 additions & 2 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,11 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non
Returns:
The extracted usage metrics and latency.
"""
usage = Usage(**event["usage"])
metrics = Metrics(**event["metrics"])
# MetadataEvent has total=False, making all fields optional, but Usage and Metrics types
# have Required fields. Provide defaults to handle cases where custom models don't
# provide usage/metrics (e.g., when latency info is unavailable).
usage = Usage(**{"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, **event.get("usage", {})})
metrics = Metrics(**{"latencyMs": 0, **event.get("metrics", {})})
if time_to_first_byte_ms:
metrics["timeToFirstByteMs"] = time_to_first_byte_ms

Expand Down
Loading
Loading