Skip to content

Commit 9da14c8

Browse files
siddhantwaghjaleWorkshop Participant
authored andcommitted
fix: handle multiple tool calls in Mistral streaming responses (strands-agents#384)
1 parent 976ca32 commit 9da14c8

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

src/strands/models/mistral.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
422422
yield {"chunk_type": "message_start"}
423423

424424
content_started = False
425-
current_tool_calls: dict[str, dict[str, str]] = {}
425+
tool_calls: dict[str, list[Any]] = {}
426426
accumulated_text = ""
427427

428428
async for chunk in stream_response:
@@ -443,24 +443,23 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
443443
if hasattr(delta, "tool_calls") and delta.tool_calls:
444444
for tool_call in delta.tool_calls:
445445
tool_id = tool_call.id
446+
tool_calls.setdefault(tool_id, []).append(tool_call)
446447

447-
if tool_id not in current_tool_calls:
448-
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
449-
current_tool_calls[tool_id] = {"name": tool_call.function.name, "arguments": ""}
448+
if hasattr(choice, "finish_reason") and choice.finish_reason:
449+
if content_started:
450+
yield {"chunk_type": "content_stop", "data_type": "text"}
451+
452+
for tool_deltas in tool_calls.values():
453+
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
450454

451-
if hasattr(tool_call.function, "arguments"):
452-
current_tool_calls[tool_id]["arguments"] += tool_call.function.arguments
455+
for tool_delta in tool_deltas:
456+
if hasattr(tool_delta.function, "arguments"):
453457
yield {
454458
"chunk_type": "content_delta",
455459
"data_type": "tool",
456-
"data": tool_call.function.arguments,
460+
"data": tool_delta.function.arguments,
457461
}
458462

459-
if hasattr(choice, "finish_reason") and choice.finish_reason:
460-
if content_started:
461-
yield {"chunk_type": "content_stop", "data_type": "text"}
462-
463-
for _ in current_tool_calls:
464463
yield {"chunk_type": "content_stop", "data_type": "tool"}
465464

466465
yield {"chunk_type": "message_stop", "data": choice.finish_reason}

tests-integ/test_model_mistral.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,41 +78,32 @@ class Weather(BaseModel):
7878

7979
@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
8080
def test_agent_invoke(agent):
81-
# TODO: https://github.com/strands-agents/sdk-python/issues/374
82-
# result = streaming_agent("What is the time and weather in New York?")
83-
result = agent("What is the time in New York?")
81+
result = agent("What is the time and weather in New York?")
8482
text = result.message["content"][0]["text"].lower()
8583

86-
# assert all(string in text for string in ["12:00", "sunny"])
87-
assert all(string in text for string in ["12:00"])
84+
assert all(string in text for string in ["12:00", "sunny"])
8885

8986

9087
@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
9188
@pytest.mark.asyncio
9289
async def test_agent_invoke_async(agent):
93-
# TODO: https://github.com/strands-agents/sdk-python/issues/374
94-
# result = await streaming_agent.invoke_async("What is the time and weather in New York?")
95-
result = await agent.invoke_async("What is the time in New York?")
90+
result = await agent.invoke_async("What is the time and weather in New York?")
9691
text = result.message["content"][0]["text"].lower()
9792

98-
# assert all(string in text for string in ["12:00", "sunny"])
99-
assert all(string in text for string in ["12:00"])
93+
assert all(string in text for string in ["12:00", "sunny"])
10094

10195

10296
@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
10397
@pytest.mark.asyncio
10498
async def test_agent_stream_async(agent):
105-
# TODO: https://github.com/strands-agents/sdk-python/issues/374
106-
# stream = streaming_agent.stream_async("What is the time and weather in New York?")
107-
stream = agent.stream_async("What is the time in New York?")
99+
stream = agent.stream_async("What is the time and weather in New York?")
108100
async for event in stream:
109101
_ = event
110102

111103
result = event["result"]
112104
text = result.message["content"][0]["text"].lower()
113105

114-
# assert all(string in text for string in ["12:00", "sunny"])
115-
assert all(string in text for string in ["12:00"])
106+
assert all(string in text for string in ["12:00", "sunny"])
116107

117108

118109
@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")

0 commit comments

Comments
 (0)