Commit 9f2f13d

models - mistral - init client on every request (#434)
1 parent 12b948a commit 9f2f13d
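This change stops caching a single `mistralai.Mistral` client on the model instance. The constructor now only records the client arguments, and each request path (non-streaming, streaming, and structured output) opens a fresh client via `async with mistralai.Mistral(**self.client_args)`, so the client's lifetime matches the request.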

File tree

src/strands/models/mistral.py
tests/strands/models/test_mistral.py

2 files changed, +56 -54 lines changed


src/strands/models/mistral.py

Lines changed: 52 additions & 50 deletions
@@ -90,11 +90,9 @@ def __init__(
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        client_args = client_args or {}
+        self.client_args = client_args or {}
         if api_key:
-            client_args["api_key"] = api_key
-
-        self.client = mistralai.Mistral(**client_args)
+            self.client_args["api_key"] = api_key
 
     @override
     def update_config(self, **model_config: Unpack[MistralConfig]) -> None:  # type: ignore
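After this hunk, `__init__` no longer constructs a client; it only stores the arguments for later use. A minimal sketch of the deferred-construction pattern in isolation (the class name and `complete` helper are illustrative, not part of the SDK):

```python
import mistralai


class PerRequestClientModel:
    """Illustrative sketch: store client kwargs, build a client per call."""

    def __init__(self, api_key=None, **client_args):
        self.client_args = dict(client_args)
        if api_key:
            self.client_args["api_key"] = api_key
        # Note: no mistralai.Mistral(...) is constructed here anymore.

    async def complete(self, **request):
        # A fresh client per request, closed automatically on exit, so no
        # connection state outlives the call.
        async with mistralai.Mistral(**self.client_args) as client:
            return await client.chat.complete_async(**request)
```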
@@ -421,67 +419,70 @@ async def stream(
             logger.debug("got response from model")
             if not self.config.get("stream", True):
                 # Use non-streaming API
-                response = await self.client.chat.complete_async(**request)
-                for event in self._handle_non_streaming_response(response):
-                    yield self.format_chunk(event)
+                async with mistralai.Mistral(**self.client_args) as client:
+                    response = await client.chat.complete_async(**request)
+                    for event in self._handle_non_streaming_response(response):
+                        yield self.format_chunk(event)
+
                 return
 
             # Use the streaming API
-            stream_response = await self.client.chat.stream_async(**request)
+            async with mistralai.Mistral(**self.client_args) as client:
+                stream_response = await client.chat.stream_async(**request)
 
-            yield self.format_chunk({"chunk_type": "message_start"})
+                yield self.format_chunk({"chunk_type": "message_start"})
 
-            content_started = False
-            tool_calls: dict[str, list[Any]] = {}
-            accumulated_text = ""
+                content_started = False
+                tool_calls: dict[str, list[Any]] = {}
+                accumulated_text = ""
 
-            async for chunk in stream_response:
-                if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
-                    choice = chunk.data.choices[0]
+                async for chunk in stream_response:
+                    if hasattr(chunk, "data") and hasattr(chunk.data, "choices") and chunk.data.choices:
+                        choice = chunk.data.choices[0]
 
-                    if hasattr(choice, "delta"):
-                        delta = choice.delta
+                        if hasattr(choice, "delta"):
+                            delta = choice.delta
 
-                        if hasattr(delta, "content") and delta.content:
-                            if not content_started:
-                                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
-                                content_started = True
+                            if hasattr(delta, "content") and delta.content:
+                                if not content_started:
+                                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+                                    content_started = True
 
-                            yield self.format_chunk(
-                                {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
-                            )
-                            accumulated_text += delta.content
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": delta.content}
+                                )
+                                accumulated_text += delta.content
 
-                        if hasattr(delta, "tool_calls") and delta.tool_calls:
-                            for tool_call in delta.tool_calls:
-                                tool_id = tool_call.id
-                                tool_calls.setdefault(tool_id, []).append(tool_call)
+                            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                                for tool_call in delta.tool_calls:
+                                    tool_id = tool_call.id
+                                    tool_calls.setdefault(tool_id, []).append(tool_call)
 
-                    if hasattr(choice, "finish_reason") and choice.finish_reason:
-                        if content_started:
-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+                        if hasattr(choice, "finish_reason") and choice.finish_reason:
+                            if content_started:
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
 
-                        for tool_deltas in tool_calls.values():
-                            yield self.format_chunk(
-                                {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
-                            )
+                            for tool_deltas in tool_calls.values():
+                                yield self.format_chunk(
+                                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                                )
 
-                            for tool_delta in tool_deltas:
-                                if hasattr(tool_delta.function, "arguments"):
-                                    yield self.format_chunk(
-                                        {
-                                            "chunk_type": "content_delta",
-                                            "data_type": "tool",
-                                            "data": tool_delta.function.arguments,
-                                        }
-                                    )
+                                for tool_delta in tool_deltas:
+                                    if hasattr(tool_delta.function, "arguments"):
+                                        yield self.format_chunk(
+                                            {
+                                                "chunk_type": "content_delta",
+                                                "data_type": "tool",
+                                                "data": tool_delta.function.arguments,
+                                            }
+                                        )
 
-                            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+                                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
 
-                        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+                            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
 
-                if hasattr(chunk, "usage"):
-                    yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
+                    if hasattr(chunk, "usage"):
+                        yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage})
 
         except Exception as e:
             if "rate" in str(e).lower() or "429" in str(e):
@@ -518,7 +519,8 @@ async def structured_output(
         formatted_request["tool_choice"] = "any"
         formatted_request["parallel_tool_calls"] = False
 
-        response = await self.client.chat.complete_async(**formatted_request)
+        async with mistralai.Mistral(**self.client_args) as client:
+            response = await client.chat.complete_async(**formatted_request)
 
         if response.choices and response.choices[0].message.tool_calls:
            tool_call = response.choices[0].message.tool_calls[0]
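By contrast with the streaming path, `complete_async` returns a fully materialized response, so here the context manager can close before the response is inspected (the `if response.choices` check above stays outside the `async with`). A hypothetical helper mirroring that shape:

```python
import mistralai


async def first_tool_call(client_args, formatted_request):
    # Hypothetical helper, not part of the commit.
    async with mistralai.Mistral(**client_args) as client:
        response = await client.chat.complete_async(**formatted_request)

    # Safe after the client closes: the response object is already complete.
    if response.choices and response.choices[0].message.tool_calls:
        return response.choices[0].message.tool_calls[0]
    return None
```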

tests/strands/models/test_mistral.py

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,9 @@
 @pytest.fixture
 def mistral_client():
     with unittest.mock.patch.object(strands.models.mistral.mistralai, "Mistral") as mock_client_cls:
-        yield mock_client_cls.return_value
+        mock_client = unittest.mock.AsyncMock()
+        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        yield mock_client
 
 
 @pytest.fixture
@@ -25,9 +27,7 @@ def max_tokens():
 
 
 @pytest.fixture
-def model(mistral_client, model_id, max_tokens):
-    _ = mistral_client
-
+def model(model_id, max_tokens):
     return MistralModel(model_id=model_id, max_tokens=max_tokens)
 
 
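Because the model now enters `async with mistralai.Mistral(...)`, the patched class must behave like an async context manager; that is why the fixture routes `__aenter__.return_value` to an `AsyncMock`. A self-contained illustration of the wiring (assumes Python 3.8+ `unittest.mock`):

```python
import asyncio
import unittest.mock


async def demo():
    # Stand-in for the patched mistralai.Mistral class.
    mock_client_cls = unittest.mock.MagicMock()
    mock_client = unittest.mock.AsyncMock()
    mock_client_cls.return_value.__aenter__.return_value = mock_client

    # What the model does internally: the bound name is the inner AsyncMock,
    # so tests can stub e.g. mock_client.chat.complete_async.return_value.
    async with mock_client_cls() as client:
        assert client is mock_client


asyncio.run(demo())
```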