models - mistral - async #375

Merged: 1 commit, Jul 8, 2025.
12 changes: 6 additions & 6 deletions src/strands/models/mistral.py
@@ -8,7 +8,7 @@
import logging
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union

-from mistralai import Mistral
+import mistralai
from pydantic import BaseModel
from typing_extensions import TypedDict, Unpack, override

@@ -94,7 +94,7 @@ def __init__(
        if api_key:
            client_args["api_key"] = api_key

-        self.client = Mistral(**client_args)
+        self.client = mistralai.Mistral(**client_args)

    @override
    def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
@@ -408,21 +408,21 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
        try:
            if not self.config.get("stream", True):
                # Use non-streaming API
-                response = self.client.chat.complete(**request)
+                response = await self.client.chat.complete_async(**request)
                for event in self._handle_non_streaming_response(response):
                    yield event
                return

            # Use the streaming API
-            stream_response = self.client.chat.stream(**request)
+            stream_response = await self.client.chat.stream_async(**request)

            yield {"chunk_type": "message_start"}

            content_started = False
            current_tool_calls: dict[str, dict[str, str]] = {}
            accumulated_text = ""

-            for chunk in stream_response:
+            async for chunk in stream_response:
                if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
                    choice = chunk.data.choices[0]

@@ -499,7 +499,7 @@ async def structured_output(
        formatted_request["tool_choice"] = "any"
        formatted_request["parallel_tool_calls"] = False

-        response = self.client.chat.complete(**formatted_request)
+        response = await self.client.chat.complete_async(**formatted_request)

        if response.choices and response.choices[0].message.tool_calls:
            tool_call = response.choices[0].message.tool_calls[0]
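The production change, in short: every blocking client call is replaced by its awaited async counterpart (chat.complete becomes an awaited chat.complete_async, chat.stream becomes an awaited chat.stream_async), and the stream is consumed with async for instead of for. A minimal standalone sketch of the same pattern, assuming the mistralai v1 SDK as used above (model name and prompts are illustrative):

import asyncio
import os

import mistralai


async def main() -> None:
    client = mistralai.Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # Non-streaming: one await returns the complete response.
    response = await client.chat.complete_async(
        model="mistral-medium-latest",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)

    # Streaming: await the call to obtain an async iterator, then consume
    # the chunks with `async for`, exactly as stream() does above.
    stream = await client.chat.stream_async(
        model="mistral-medium-latest",
        messages=[{"role": "user", "content": "Count to three."}],
    )
    async for chunk in stream:
        if chunk.data.choices and chunk.data.choices[0].delta.content:
            print(chunk.data.choices[0].delta.content, end="")


asyncio.run(main())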
143 changes: 59 additions & 84 deletions tests-integ/test_model_mistral.py
@@ -8,7 +8,7 @@
from strands.models.mistral import MistralModel


-@pytest.fixture
+@pytest.fixture(scope="module")
def streaming_model():
    return MistralModel(
        model_id="mistral-medium-latest",
@@ -20,7 +20,7 @@ def streaming_model():
    )


-@pytest.fixture
+@pytest.fixture(scope="module")
def non_streaming_model():
    return MistralModel(
        model_id="mistral-medium-latest",
@@ -32,126 +32,101 @@ def non_streaming_model():
    )


-@pytest.fixture
+@pytest.fixture(scope="module")
def system_prompt():
    return "You are an AI assistant that provides helpful and accurate information."


-@pytest.fixture
-def calculator_tool():
-    @strands.tool
-    def calculator(expression: str) -> float:
-        """Calculate the result of a mathematical expression."""
-        return eval(expression)

-    return calculator


-@pytest.fixture
-def weather_tools():
+@pytest.fixture(scope="module")
+def tools():
    @strands.tool
    def tool_time() -> str:
        """Get the current time."""
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        """Get the current weather."""
        return "sunny"

    return [tool_time, tool_weather]


-@pytest.fixture
-def streaming_agent(streaming_model):
-    return Agent(model=streaming_model)
+@pytest.fixture(scope="module")
+def streaming_agent(streaming_model, tools):
+    return Agent(model=streaming_model, tools=tools)


-@pytest.fixture
-def non_streaming_agent(non_streaming_model):
-    return Agent(model=non_streaming_model)
+@pytest.fixture(scope="module")
+def non_streaming_agent(non_streaming_model, tools):
+    return Agent(model=non_streaming_model, tools=tools)


-@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_streaming_agent_basic(streaming_agent):
-    """Test basic streaming agent functionality."""
-    result = streaming_agent("Tell me about Agentic AI in one sentence.")
+@pytest.fixture(params=["streaming_agent", "non_streaming_agent"])
+def agent(request):
+    return request.getfixturevalue(request.param)

-    assert len(str(result)) > 0
-    assert hasattr(result, "message")
-    assert "content" in result.message

+@pytest.fixture(scope="module")
+def weather():
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""

-@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_non_streaming_agent_basic(non_streaming_agent):
-    """Test basic non-streaming agent functionality."""
-    result = non_streaming_agent("Tell me about Agentic AI in one sentence.")
+        time: str
+        weather: str

-    assert len(str(result)) > 0
-    assert hasattr(result, "message")
-    assert "content" in result.message
+    return Weather(time="12:00", weather="sunny")


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_tool_use_streaming(streaming_model):
-    """Test tool use with streaming model."""

-    @strands.tool
-    def calculator(expression: str) -> float:
-        """Calculate the result of a mathematical expression."""
-        return eval(expression)

-    agent = Agent(model=streaming_model, tools=[calculator])
-    result = agent("What is the square root of 1764")
+def test_agent_invoke(agent):
+    # TODO: https://github.com/strands-agents/sdk-python/issues/374
+    # result = streaming_agent("What is the time and weather in New York?")
+    result = agent("What is the time in New York?")
+    text = result.message["content"][0]["text"].lower()

-    # Verify the result contains the calculation
-    text_content = str(result).lower()
-    assert "42" in text_content
+    # assert all(string in text for string in ["12:00", "sunny"])
+    assert all(string in text for string in ["12:00"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_tool_use_non_streaming(non_streaming_model):
-    """Test tool use with non-streaming model."""
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    # TODO: https://github.com/strands-agents/sdk-python/issues/374
+    # result = await streaming_agent.invoke_async("What is the time and weather in New York?")
+    result = await agent.invoke_async("What is the time in New York?")
+    text = result.message["content"][0]["text"].lower()

-    @strands.tool
-    def calculator(expression: str) -> float:
-        """Calculate the result of a mathematical expression."""
-        return eval(expression)

-    agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False)
-    result = agent("What is the square root of 1764")

-    text_content = str(result).lower()
-    assert "42" in text_content
+    # assert all(string in text for string in ["12:00", "sunny"])
+    assert all(string in text for string in ["12:00"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_structured_output_streaming(streaming_model):
-    """Test structured output with streaming model."""

-    class Weather(BaseModel):
-        time: str
-        weather: str
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    # TODO: https://github.com/strands-agents/sdk-python/issues/374
+    # stream = streaming_agent.stream_async("What is the time and weather in New York?")
+    stream = agent.stream_async("What is the time in New York?")
+    async for event in stream:
+        _ = event

-    agent = Agent(model=streaming_model)
-    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()

-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+    # assert all(string in text for string in ["12:00", "sunny"])
+    assert all(string in text for string in ["12:00"])


@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
-def test_structured_output_non_streaming(non_streaming_model):
-    """Test structured output with non-streaming model."""
+def test_agent_structured_output(non_streaming_agent, weather):
+    tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather

-    class Weather(BaseModel):
-        time: str
-        weather: str

-    agent = Agent(model=non_streaming_model)
-    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")

-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(non_streaming_agent, weather):
+    tru_weather = await non_streaming_agent.structured_output_async(
+        type(weather), "The time is 12:00 and the weather is sunny"
+    )
+    exp_weather = weather
+    assert tru_weather == exp_weather
45 changes: 39 additions & 6 deletions tests/strands/models/test_mistral.py
@@ -10,7 +10,7 @@

@pytest.fixture
def mistral_client():
-    with unittest.mock.patch.object(strands.models.mistral, "Mistral") as mock_client_cls:
+    with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls:
        yield mock_client_cls.return_value


@@ -436,17 +436,50 @@ def test_format_chunk_unknown(model):
        model.format_chunk(event)


+@pytest.mark.asyncio
+async def test_stream(mistral_client, model, agenerator, alist):
+    mock_event = unittest.mock.Mock(
+        data=unittest.mock.Mock(
+            choices=[
+                unittest.mock.Mock(
+                    delta=unittest.mock.Mock(content="test stream", tool_calls=None),
+                    finish_reason="end_turn",
+                )
+            ]
+        ),
+        usage="usage",
+    )
+
+    mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
+
+    request = {"model": "m1"}
+    response = model.stream(request)
+
+    tru_events = await alist(response)
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_start", "data_type": "text"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "test stream"},
+        {"chunk_type": "content_stop", "data_type": "text"},
+        {"chunk_type": "message_stop", "data": "end_turn"},
+        {"chunk_type": "metadata", "data": "usage"},
+    ]
+    assert tru_events == exp_events
+
+    mistral_client.chat.stream_async.assert_called_once_with(**request)
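The new test leans on two helpers, agenerator and alist, injected as fixtures but not shown in this diff. A plausible minimal implementation of such conftest fixtures (an assumption; the repository's actual helpers may differ):

import pytest


@pytest.fixture
def agenerator():
    # Wrap a plain iterable in an async generator so it can be consumed
    # with `async for`, mimicking what chat.stream_async returns.
    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    # Drain an async iterable into an ordinary list for easy assertions.
    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect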


@pytest.mark.asyncio
async def test_stream_rate_limit_error(mistral_client, model, alist):
-    mistral_client.chat.stream.side_effect = Exception("rate limit exceeded (429)")
+    mistral_client.chat.stream_async.side_effect = Exception("rate limit exceeded (429)")

    with pytest.raises(ModelThrottledException, match="rate limit exceeded"):
        await alist(model.stream({}))


@pytest.mark.asyncio
async def test_stream_other_error(mistral_client, model, alist):
-    mistral_client.chat.stream.side_effect = Exception("some other error")
+    mistral_client.chat.stream_async.side_effect = Exception("some other error")

    with pytest.raises(Exception, match="some other error"):
        await alist(model.stream({}))
@@ -461,7 +494,7 @@ async def test_structured_output_success(mistral_client, model, test_output_mode
    mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls[0].function.arguments = '{"name": "John", "age": 30}'

-    mistral_client.chat.complete.return_value = mock_response
+    mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response)

    stream = model.structured_output(test_output_model_cls, messages)
    events = await alist(stream)
@@ -477,7 +510,7 @@ async def test_structured_output_no_tool_calls(mistral_client, model, test_outpu
    mock_response.choices = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls = None

-    mistral_client.chat.complete.return_value = mock_response
+    mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response)

    prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]

@@ -493,7 +526,7 @@ async def test_structured_output_invalid_json(mistral_client, model, test_output
    mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls[0].function.arguments = "invalid json"

-    mistral_client.chat.complete.return_value = mock_response
+    mistral_client.chat.complete_async = unittest.mock.AsyncMock(return_value=mock_response)

    prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]

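A closing note on the mocking change that recurs in this file: a plain Mock attribute returns another Mock when called, which cannot be awaited, so each async client method is stubbed with unittest.mock.AsyncMock. A self-contained illustration of the difference (the client shape is a placeholder):

import asyncio
import unittest.mock


async def main() -> None:
    client = unittest.mock.Mock()

    # Stub the async method; calling an AsyncMock returns an awaitable.
    client.chat.complete_async = unittest.mock.AsyncMock(return_value="response")

    result = await client.chat.complete_async(model="m1")
    assert result == "response"

    # AsyncMock also records awaits, enabling await-specific assertions.
    client.chat.complete_async.assert_awaited_once_with(model="m1")


asyncio.run(main())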