Commit c363032
Merge branch 'strands-agents:main' into session-persistence
2 parents: 838ccaa + 5dc3f59

File tree: 9 files changed (+128, -70 lines)

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -90,7 +90,7 @@ writer = [
 ]

 a2a = [
-    "a2a-sdk[sql]>=0.2.11",
+    "a2a-sdk[sql]>=0.2.11,<1.0.0",
     "uvicorn>=0.34.2",
     "httpx>=0.28.1",
     "fastapi>=0.115.12",
@@ -136,7 +136,7 @@ all = [
     "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",

     # a2a
-    "a2a-sdk[sql]>=0.2.11",
+    "a2a-sdk[sql]>=0.2.11,<1.0.0",
     "uvicorn>=0.34.2",
     "httpx>=0.28.1",
     "fastapi>=0.115.12",


src/strands/models/mistral.py

Lines changed: 52 additions & 50 deletions

@@ -90,11 +90,9 @@ def __init__(

         logger.debug("config=<%s> | initializing", self.config)

-        client_args = client_args or {}
+        self.client_args = client_args or {}
         if api_key:
-            client_args["api_key"] = api_key
-
-        self.client = mistralai.Mistral(**client_args)
+            self.client_args["api_key"] = api_key

     @override
     def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
@@ -421,67 +419,70 @@ async def stream(
             logger.debug("got response from model")
             if not self.config.get("stream", True):
                 # Use non-streaming API
-                response = await self.client.chat.complete_async(**request)
-                for event in self._handle_non_streaming_response(response):
-                    yield self.format_chunk(event)
+                async with mistralai.Mistral(**self.client_args) as client:
+                    response = await client.chat.complete_async(**request)
+                    for event in self._handle_non_streaming_response(response):
+                        yield self.format_chunk(event)
+
                 return

             # Use the streaming API
-            stream_response = await self.client.chat.stream_async(**request)
+            async with mistralai.Mistral(**self.client_args) as client:
+                stream_response = await client.chat.stream_async(**request)

-            yield self.format_chunk({"chunk_type": "message_start"})
+                yield self.format_chunk({"chunk_type": "message_start"})

-            content_started = False
-            tool_calls: dict[str, list[Any]] = {}
-            accumulated_text = ""
+                content_started = False
+                tool_calls: dict[str, list[Any]] = {}
+                accumulated_text = ""

-            async for chunk in stream_response:
-                if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
-                    choice = chunk.data.choices[0]
+                async for chunk in stream_response:
+                    if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
+                        choice = chunk.data.choices[0]

-                    if hasattr(choice, "delta"):
-                        delta = choice.delta
+                        if hasattr(choice, "delta"):
+                            delta = choice.delta

-                        if hasattr(delta, "content") and delta.content:
-                            if not content_started:
-                                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-                                content_started = True
+                            if hasattr(delta, "content") and delta.content:
+                                if not content_started:
+                                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                                    content_started = True

-                            yield self.format_chunk(
-                                {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
-                            )
-                            accumulated_text += delta.content
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
+                                )
+                                accumulated_text += delta.content

-                        if hasattr(delta, "tool_calls") and delta.tool_calls:
-                            for tool_call in delta.tool_calls:
-                                tool_id = tool_call.id
-                                tool_calls.setdefault(tool_id, []).append(tool_call)
+                            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                                for tool_call in delta.tool_calls:
+                                    tool_id = tool_call.id
+                                    tool_calls.setdefault(tool_id, []).append(tool_call)

-                    if hasattr(choice, "finish_reason") and choice.finish_reason:
-                        if content_started:
-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+                        if hasattr(choice, "finish_reason") and choice.finish_reason:
+                            if content_started:
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

-                        for tool_deltas in tool_calls.values():
-                            yield self.format_chunk(
-                                {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
-                            )
+                            for tool_deltas in tool_calls.values():
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                                )

-                            for tool_delta in tool_deltas:
-                                if hasattr(tool_delta.function, "arguments"):
-                                    yield self.format_chunk(
-                                        {
-                                            "chunk_type": "content_delta",
-                                            "data_type": "tool",
-                                            "data": tool_delta.function.arguments,
-                                        }
-                                    )
+                                for tool_delta in tool_deltas:
+                                    if hasattr(tool_delta.function, "arguments"):
+                                        yield self.format_chunk(
+                                            {
+                                                "chunk_type": "content_delta",
+                                                "data_type": "tool",
+                                                "data": tool_delta.function.arguments,
+                                            }
+                                        )

-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

-                        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+                            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

-                if hasattr(chunk, "usage"):
-                    yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+                    if hasattr(chunk, "usage"):
+                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})

         except Exception as e:
             if "rate" in str(e).lower() or "429" in str(e):
@@ -518,7 +519,8 @@ async def structured_output(
         formatted_request["tool_choice"] = "any"
         formatted_request["parallel_tool_calls"] = False

-        response = await self.client.chat.complete_async(**formatted_request)
+        async with mistralai.Mistral(**self.client_args) as client:
+            response = await client.chat.complete_async(**formatted_request)

         if response.choices and response.choices[0].message.tool_calls:
             tool_call = response.choices[0].message.tool_calls[0]

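Reviewer note: the net effect is that MistralModel no longer holds a long-lived mistralai.Mistral client; it stores the constructor arguments and opens a short-lived client per request, closed deterministically by the async context manager. A minimal sketch of the pattern in isolation (the model id is illustrative, not from the diff):

    import mistralai

    async def complete_once(prompt: str, **client_args) -> str:
        # A fresh client per call; nothing connection-scoped outlives
        # the request or its event loop.
        async with mistralai.Mistral(**client_args) as client:
            response = await client.chat.complete_async(
                model="mistral-small-latest",  # illustrative model id
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content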

src/strands/models/ollama.py

Lines changed: 7 additions & 6 deletions

@@ -68,14 +68,12 @@ def __init__(
             ollama_client_args: Additional arguments for the Ollama client.
             **model_config: Configuration options for the Ollama model.
         """
+        self.host = host
+        self.client_args = ollama_client_args or {}
         self.config = OllamaModel.OllamaConfig(**model_config)

         logger.debug("config=<%s> | initializing", self.config)

-        ollama_client_args = ollama_client_args if ollama_client_args is not None else {}
-
-        self.client = ollama.AsyncClient(host, **ollama_client_args)
-
     @override
     def update_config(self, **model_config: Unpack[OllamaConfig]) -> None:  # type: ignore
         """Update the Ollama Model configuration with the provided arguments.
@@ -306,7 +304,8 @@ async def stream(
         logger.debug("invoking model")
         tool_requested = False

-        response = await self.client.chat(**request)
+        client = ollama.AsyncClient(self.host, **self.client_args)
+        response = await client.chat(**request)

         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -346,7 +345,9 @@ async def structured_output(
         formatted_request = self.format_request(messages=prompt)
         formatted_request["format"] = output_model.model_json_schema()
         formatted_request["stream"] = False
-        response = await self.client.chat(**formatted_request)
+
+        client = ollama.AsyncClient(self.host, **self.client_args)
+        response = await client.chat(**formatted_request)

         try:
             content = response.message.content.strip()

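Reviewer note: same refactor as mistral.py, but ollama.AsyncClient is cheap to build, so it is simply constructed per call from the stored host and args rather than wrapped in a context manager. The pattern in isolation (host and model name are illustrative):

    import ollama

    async def chat_once(prompt: str):
        # Build the client at call time from stored connection details,
        # mirroring what OllamaModel.stream() and structured_output() now do.
        client = ollama.AsyncClient("http://localhost:11434")  # illustrative host
        return await client.chat(
            model="llama3",  # illustrative model name
            messages=[{"role": "user", "content": prompt}],
        )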

src/strands/tools/mcp/mcp_client.py

Lines changed: 5 additions & 4 deletions

@@ -16,13 +16,14 @@
 from concurrent import futures
 from datetime import timedelta
 from types import TracebackType
-from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union

 from mcp import ClientSession, ListToolsResult
 from mcp.types import CallToolResult as MCPCallToolResult
 from mcp.types import ImageContent as MCPImageContent
 from mcp.types import TextContent as MCPTextContent

+from ...types import PaginatedList
 from ...types.exceptions import MCPClientInitializationError
 from ...types.media import ImageFormat
 from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus
@@ -140,7 +141,7 @@ async def _set_close_event() -> None:
         self._background_thread = None
         self._session_id = uuid.uuid4()

-    def list_tools_sync(self) -> List[MCPAgentTool]:
+    def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]:
         """Synchronously retrieves the list of available tools from the MCP server.

         This method calls the asynchronous list_tools method on the MCP session
@@ -154,14 +155,14 @@ def list_tools_sync(self) -> List[MCPAgentTool]:
             raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

         async def _list_tools_async() -> ListToolsResult:
-            return await self._background_thread_session.list_tools()
+            return await self._background_thread_session.list_tools(cursor=pagination_token)

         list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
         self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))

         mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools]
         self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
-        return mcp_tools
+        return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)

     def call_tool_sync(
         self,

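Reviewer note: callers that want every tool can now follow the cursor until the server stops returning one. A sketch of the consuming loop, assuming mcp_client is an MCPClient already inside its with block:

    all_tools = []
    page = mcp_client.list_tools_sync()
    all_tools.extend(page)

    # pagination_token carries the server's nextCursor; None means the last page.
    while page.pagination_token is not None:
        page = mcp_client.list_tools_sync(pagination_token=page.pagination_token)
        all_tools.extend(page)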

src/strands/types/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -1 +1,5 @@
 """SDK type definitions."""
+
+from .collections import PaginatedList
+
+__all__ = ["PaginatedList"]

src/strands/types/collections.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+"""Generic collection types for the Strands SDK."""
+
+from typing import Generic, List, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+class PaginatedList(list, Generic[T]):
+    """A generic list-like object that includes a pagination token.
+
+    This maintains backwards compatibility by inheriting from list,
+    so existing code that expects List[T] will continue to work.
+    """
+
+    def __init__(self, data: List[T], token: Optional[str] = None):
+        """Initialize a PaginatedList with data and an optional pagination token.
+
+        Args:
+            data: The list of items to store.
+            token: Optional pagination token for retrieving additional items.
+        """
+        super().__init__(data)
+        self.pagination_token = token

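Reviewer note: because PaginatedList subclasses list, existing callers typed against List[MCPAgentTool] keep working; the cursor simply rides along as an attribute. A quick demonstration:

    from strands.types import PaginatedList

    page = PaginatedList(["alpha", "beta"], token="next_page")
    assert isinstance(page, list)                  # drop-in for List[str]
    assert page[0] == "alpha" and len(page) == 2   # normal list behavior
    assert page.pagination_token == "next_page"    # extra pagination state

    last = PaginatedList(["gamma"])                # token defaults to None
    assert last.pagination_token is None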

tests/strands/models/test_mistral.py

Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,9 @@
 @pytest.fixture
 def mistral_client():
     with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls:
-        yield mock_client_cls.return_value
+        mock_client = unittest.mock.AsyncMock()
+        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        yield mock_client


 @pytest.fixture
@@ -25,9 +27,7 @@ def max_tokens():


 @pytest.fixture
-def model(mistral_client, model_id, max_tokens):
-    _ = mistral_client
-
+def model(model_id, max_tokens):
     return MistralModel(model_id=model_id, max_tokens=max_tokens)

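Reviewer note: the fixture change tracks the async-context-manager refactor; `as client` now binds whatever `__aenter__` returns, so the mock must route through that protocol. A standalone illustration of why the extra wiring is needed:

    import asyncio
    import unittest.mock

    mock_cls = unittest.mock.MagicMock()
    inner = unittest.mock.AsyncMock()
    mock_cls.return_value.__aenter__.return_value = inner

    async def demo():
        # `async with` awaits __aenter__ on the instance; the fixture's
        # wiring makes that return the AsyncMock the tests assert on.
        async with mock_cls() as client:
            assert client is inner

    asyncio.run(demo())
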
tests/strands/models/test_ollama.py

Lines changed: 1 addition & 3 deletions

@@ -26,9 +26,7 @@ def host():


 @pytest.fixture
-def model(ollama_client, model_id, host):
-    _ = ollama_client
-
+def model(model_id, host):
     return OllamaModel(host, model_id=model_id)
tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 30 additions & 1 deletion

@@ -71,10 +71,11 @@ def test_list_tools_sync(mock_transport, mock_session):
     with MCPClient(mock_transport["transport_callable"]) as client:
         tools = client.list_tools_sync()

-    mock_session.list_tools.assert_called_once()
+    mock_session.list_tools.assert_called_once_with(cursor=None)

     assert len(tools) == 1
     assert tools[0].tool_name == "test_tool"
+    assert tools.pagination_token is None


 def test_list_tools_sync_session_not_active():
@@ -85,6 +86,34 @@ def test_list_tools_sync_session_not_active():
         client.list_tools_sync()


+def test_list_tools_sync_with_pagination_token(mock_transport, mock_session):
+    """Test that list_tools_sync correctly passes pagination token and returns next cursor."""
+    mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}})
+    mock_session.list_tools.return_value = ListToolsResult(tools=[mock_tool], nextCursor="next_page_token")
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        tools = client.list_tools_sync(pagination_token="current_page_token")
+
+    mock_session.list_tools.assert_called_once_with(cursor="current_page_token")
+    assert len(tools) == 1
+    assert tools[0].tool_name == "test_tool"
+    assert tools.pagination_token == "next_page_token"
+
+
+def test_list_tools_sync_without_pagination_token(mock_transport, mock_session):
+    """Test that list_tools_sync works without pagination token and handles missing next cursor."""
+    mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}})
+    mock_session.list_tools.return_value = ListToolsResult(tools=[mock_tool])  # No nextCursor
+
+    with MCPClient(mock_transport["transport_callable"]) as client:
+        tools = client.list_tools_sync()
+
+    mock_session.list_tools.assert_called_once_with(cursor=None)
+    assert len(tools) == 1
+    assert tools[0].tool_name == "test_tool"
+    assert tools.pagination_token is None
+
+
 @pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
 def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_status):
     """Test that call_tool_sync correctly handles success and error results."""
