Skip to content

Commit 5ca7b93

Browse files
authored
Merge branch 'main' into agent-kwargs
2 parents 53ebaa0 + 0bd01d6 commit 5ca7b93

35 files changed

+2479
-912
lines changed

pyproject.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ litellm = [
7070
llamaapi = [
7171
"llama-api-client>=0.1.0,<1.0.0",
7272
]
73+
mistral = [
74+
"mistralai>=1.8.2",
75+
]
7376
ollama = [
7477
"ollama>=0.4.8,<1.0.0",
7578
]
@@ -92,7 +95,7 @@ a2a = [
9295
source = "vcs"
9396

9497
[tool.hatch.envs.hatch-static-analysis]
95-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
98+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
9699
dependencies = [
97100
"mypy>=1.15.0,<2.0.0",
98101
"ruff>=0.11.6,<0.12.0",
@@ -116,7 +119,7 @@ lint-fix = [
116119
]
117120

118121
[tool.hatch.envs.hatch-test]
119-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
122+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel","mistral"]
120123
extra-dependencies = [
121124
"moto>=5.1.0,<6.0.0",
122125
"pytest>=8.0.0,<9.0.0",
@@ -132,7 +135,7 @@ extra-args = [
132135

133136
[tool.hatch.envs.dev]
134137
dev-mode = true
135-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"]
138+
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel","mistral"]
136139

137140
[tool.hatch.envs.a2a]
138141
dev-mode = true

src/strands/agent/agent.py

Lines changed: 50 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,18 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import asyncio
1312
import json
1413
import logging
1514
import os
1615
import random
1716
from concurrent.futures import ThreadPoolExecutor
18-
from threading import Thread
19-
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
20-
from uuid import uuid4
17+
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
2118

2219
from opentelemetry import trace
2320
from pydantic import BaseModel
2421

2522
from ..event_loop.event_loop import event_loop_cycle
26-
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
23+
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2724
from ..handlers.tool_handler import AgentToolHandler
2825
from ..models.bedrock import BedrockModel
2926
from ..telemetry.metrics import EventLoopMetrics
@@ -183,7 +180,7 @@ def __init__(
183180
self,
184181
model: Union[Model, str, None] = None,
185182
messages: Optional[Messages] = None,
186-
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
183+
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
187184
system_prompt: Optional[str] = None,
188185
callback_handler: Optional[
189186
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
@@ -255,7 +252,7 @@ def __init__(
255252
self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()
256253

257254
# Process trace attributes to ensure they're of compatible types
258-
self.trace_attributes: Dict[str, AttributeValue] = {}
255+
self.trace_attributes: dict[str, AttributeValue] = {}
259256
if trace_attributes:
260257
for k, v in trace_attributes.items():
261258
if isinstance(v, (str, int, float, bool)) or (
@@ -312,7 +309,7 @@ def tool(self) -> ToolCaller:
312309
return self.tool_caller
313310

314311
@property
315-
def tool_names(self) -> List[str]:
312+
def tool_names(self) -> list[str]:
316313
"""Get a list of all registered tool names.
317314
318315
Returns:
@@ -357,19 +354,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
357354
- metrics: Performance metrics from the event loop
358355
- state: The final state of the event loop
359356
"""
357+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
358+
360359
self._start_agent_trace_span(prompt)
361360

362361
try:
363-
# Run the event loop and get the result
364-
result = self._run_loop(prompt, kwargs)
362+
events = self._run_loop(callback_handler, prompt, kwargs)
363+
for event in events:
364+
if "callback" in event:
365+
callback_handler(**event["callback"])
366+
367+
stop_reason, message, metrics, state = event["stop"]
368+
result = AgentResult(stop_reason, message, metrics, state)
365369

366370
self._end_agent_trace_span(response=result)
367371

368372
return result
373+
369374
except Exception as e:
370375
self._end_agent_trace_span(error=e)
371-
372-
# Re-raise the exception to preserve original behavior
373376
raise
374377

375378
def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
@@ -383,9 +386,9 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
383386
instruct the model to output the structured data.
384387
385388
Args:
386-
output_model(Type[BaseModel]): The output model (a JSON schema written as a Pydantic BaseModel)
389+
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
387390
that the agent will use when responding.
388-
prompt(Optional[str]): The prompt to use for the agent.
391+
prompt: The prompt to use for the agent.
389392
"""
390393
messages = self.messages
391394
if not messages and not prompt:
@@ -396,7 +399,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
396399
messages.append({"role": "user", "content": [{"text": prompt}]})
397400

398401
# get the structured output from the model
399-
return self.model.structured_output(output_model, messages, self.callback_handler)
402+
events = self.model.structured_output(output_model, messages)
403+
for event in events:
404+
if "callback" in event:
405+
self.callback_handler(**cast(dict, event["callback"]))
406+
407+
return event["output"]
400408

401409
async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
402410
"""Process a natural language prompt and yield events as an async iterator.
@@ -428,94 +436,63 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
428436
yield event["data"]
429437
```
430438
"""
431-
self._start_agent_trace_span(prompt)
439+
callback_handler = kwargs.get("callback_handler", self.callback_handler)
432440

433-
_stop_event = uuid4()
434-
435-
queue = asyncio.Queue[Any]()
436-
loop = asyncio.get_event_loop()
437-
438-
def enqueue(an_item: Any) -> None:
439-
nonlocal queue
440-
nonlocal loop
441-
loop.call_soon_threadsafe(queue.put_nowait, an_item)
442-
443-
def queuing_callback_handler(**handler_kwargs: Any) -> None:
444-
enqueue(handler_kwargs.copy())
441+
self._start_agent_trace_span(prompt)
445442

446-
def target_callback() -> None:
447-
nonlocal kwargs
443+
try:
444+
events = self._run_loop(callback_handler, prompt, kwargs)
445+
for event in events:
446+
if "callback" in event:
447+
callback_handler(**event["callback"])
448+
yield event["callback"]
448449

449-
try:
450-
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
451-
self._end_agent_trace_span(response=result)
452-
except Exception as e:
453-
self._end_agent_trace_span(error=e)
454-
enqueue(e)
455-
finally:
456-
enqueue(_stop_event)
450+
stop_reason, message, metrics, state = event["stop"]
451+
result = AgentResult(stop_reason, message, metrics, state)
457452

458-
thread = Thread(target=target_callback, daemon=True)
459-
thread.start()
453+
self._end_agent_trace_span(response=result)
460454

461-
try:
462-
while True:
463-
item = await queue.get()
464-
if item == _stop_event:
465-
break
466-
if isinstance(item, Exception):
467-
raise item
468-
yield item
469-
finally:
470-
thread.join()
455+
except Exception as e:
456+
self._end_agent_trace_span(error=e)
457+
raise
471458

472459
def _run_loop(
473-
self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
474-
) -> AgentResult:
460+
self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
461+
) -> Generator[dict[str, Any], None, None]:
475462
"""Execute the agent's event loop with the given prompt and parameters."""
476463
try:
477-
# If the call had a callback_handler passed in, then for this event_loop
478-
# cycle we call both handlers as the callback_handler
479-
invocation_callback_handler = (
480-
CompositeCallbackHandler(self.callback_handler, supplementary_callback_handler)
481-
if supplementary_callback_handler is not None
482-
else self.callback_handler
483-
)
484-
485464
# Extract key parameters
486-
invocation_callback_handler(init_event_loop=True, **kwargs)
465+
yield {"callback": {"init_event_loop": True, **kwargs}}
487466

488467
# Set up the user message with optional knowledge base retrieval
489-
message_content: List[ContentBlock] = [{"text": prompt}]
468+
message_content: list[ContentBlock] = [{"text": prompt}]
490469
new_message: Message = {"role": "user", "content": message_content}
491470
self.messages.append(new_message)
492471

493472
# Execute the event loop cycle with retry logic for context limits
494-
return self._execute_event_loop_cycle(invocation_callback_handler, kwargs)
473+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
495474

496475
finally:
497476
self.conversation_manager.apply_management(self)
498477

499-
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]) -> AgentResult:
478+
def _execute_event_loop_cycle(
479+
self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
480+
) -> Generator[dict[str, Any], None, None]:
500481
"""Execute the event loop cycle with retry logic for context window limits.
501482
502483
This internal method handles the execution of the event loop cycle and implements
503484
retry logic for handling context window overflow exceptions by reducing the
504485
conversation context and retrying.
505486
506-
Args:
507-
callback_handler: The callback handler to use for events.
508-
kwargs: Additional parameters to pass through event loop.
509-
510-
Returns:
511-
The result of the event loop cycle.
487+
Yields:
488+
Events of the loop cycle.
512489
"""
513490
# Add `Agent` to kwargs to keep backwards-compatibility
514491
kwargs["agent"] = self
515492

516493
try:
517494
# Execute the main event loop cycle
518-
events = event_loop_cycle(
495+
yield from event_loop_cycle(
519496
model=self.model,
520497
system_prompt=self.system_prompt,
521498
messages=self.messages, # will be modified by event_loop_cycle
@@ -527,19 +504,11 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
527504
event_loop_parent_span=self.trace_span,
528505
kwargs=kwargs,
529506
)
530-
for event in events:
531-
if "callback" in event:
532-
callback_handler(**event["callback"])
533-
534-
stop_reason, message, metrics, state = event["stop"]
535-
536-
return AgentResult(stop_reason, message, metrics, state)
537507

538508
except ContextWindowOverflowException as e:
539509
# Try reducing the context size and retrying
540-
541510
self.conversation_manager.reduce_context(self, e=e)
542-
return self._execute_event_loop_cycle(callback_handler, kwargs)
511+
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
543512

544513
def _record_tool_execution(
545514
self,
@@ -625,7 +594,7 @@ def _end_agent_trace_span(
625594
error: Error to record as a trace attribute.
626595
"""
627596
if self.trace_span:
628-
trace_attributes: Dict[str, Any] = {
597+
trace_attributes: dict[str, Any] = {
629598
"span": self.trace_span,
630599
}
631600

src/strands/agent/conversation_manager/null_conversation_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def apply_management(self, _agent: "Agent") -> None:
2323
"""Does nothing to the conversation history.
2424
2525
Args:
26-
agent: The agent whose conversation history will remain unmodified.
26+
_agent: The agent whose conversation history will remain unmodified.
2727
"""
2828
pass
2929

3030
def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None:
3131
"""Does not reduce context and raises an exception.
3232
3333
Args:
34-
agent: The agent whose conversation history will remain unmodified.
34+
_agent: The agent whose conversation history will remain unmodified.
3535
e: The exception that triggered the context reduction, if any.
3636
3737
Raises:

src/strands/models/anthropic.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
import json
88
import logging
99
import mimetypes
10-
from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
10+
from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast
1111

1212
import anthropic
1313
from pydantic import BaseModel
1414
from typing_extensions import Required, Unpack, override
1515

1616
from ..event_loop.streaming import process_stream
17-
from ..handlers.callback_handler import PrintingCallbackHandler
1817
from ..tools import convert_pydantic_to_tool_spec
1918
from ..types.content import ContentBlock, Messages
2019
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
@@ -378,24 +377,24 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
378377

379378
@override
380379
def structured_output(
381-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
382-
) -> T:
380+
self, output_model: Type[T], prompt: Messages
381+
) -> Generator[dict[str, Union[T, Any]], None, None]:
383382
"""Get structured output from the model.
384383
385384
Args:
386-
output_model(Type[BaseModel]): The output model to use for the agent.
387-
prompt(Messages): The prompt messages to use for the agent.
388-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
385+
output_model: The output model to use for the agent.
386+
prompt: The prompt messages to use for the agent.
387+
388+
Yields:
389+
Model events with the last being the structured output.
389390
"""
390-
callback_handler = callback_handler or PrintingCallbackHandler()
391391
tool_spec = convert_pydantic_to_tool_spec(output_model)
392392

393393
response = self.converse(messages=prompt, tool_specs=[tool_spec])
394394
for event in process_stream(response, prompt):
395-
if "callback" in event:
396-
callback_handler(**event["callback"])
397-
else:
398-
stop_reason, messages, _, _ = event["stop"]
395+
yield event
396+
397+
stop_reason, messages, _, _ = event["stop"]
399398

400399
if stop_reason != "tool_use":
401400
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
@@ -413,4 +412,4 @@ def structured_output(
413412
if output_response is None:
414413
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
415414

416-
return output_model(**output_response)
415+
yield {"output": output_model(**output_response)}

0 commit comments

Comments
 (0)