Skip to content

Commit 3f56fbb

Browse files
zastrowmjsamuel1
authored andcommitted
fix: Migrate Mistral structured_output to an iterator (strands-agents#305)
Mistral was merged after the latest iterator changes (strands-agents#291) so it didn't have the latest and greatest --------- Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent a3a3ccc commit 3f56fbb

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

src/strands/models/mistral.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import base64
77
import json
88
import logging
9-
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union
9+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union
1010

1111
from mistralai import Mistral
1212
from pydantic import BaseModel
@@ -472,7 +472,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
472472
@override
473473
def structured_output(
474474
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
475-
) -> T:
475+
) -> Generator[dict[str, Union[T, Any]], None, None]:
476476
"""Get structured output from the model.
477477
478478
Args:
@@ -507,7 +507,8 @@ def structured_output(
507507
arguments = json.loads(tool_call.function.arguments)
508508
else:
509509
arguments = tool_call.function.arguments
510-
return output_model(**arguments)
510+
yield {"output": output_model(**arguments)}
511+
return
511512
except (json.JSONDecodeError, TypeError, ValueError) as e:
512513
raise ValueError(f"Failed to parse tool call arguments into model: {e}") from e
513514

tests/strands/models/test_mistral.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest.mock
22

3+
import pydantic
34
import pytest
45

56
import strands
@@ -58,6 +59,15 @@ def system_prompt():
5859
return "You are a helpful assistant"
5960

6061

62+
@pytest.fixture
63+
def test_output_model_cls():
64+
class TestOutputModel(pydantic.BaseModel):
65+
name: str
66+
age: int
67+
68+
return TestOutputModel
69+
70+
6171
def test__init__model_configs(mistral_client, model_id, max_tokens):
6272
_ = mistral_client
6373

@@ -440,35 +450,24 @@ def test_stream_other_error(mistral_client, model):
440450
list(model.stream({}))
441451

442452

443-
def test_structured_output_success(mistral_client, model):
444-
from pydantic import BaseModel
445-
446-
class TestModel(BaseModel):
447-
name: str
448-
age: int
453+
def test_structured_output_success(mistral_client, model, test_output_model_cls):
454+
messages = [{"role": "user", "content": [{"text": "Extract data"}]}]
449455

450-
# Mock successful response
451456
mock_response = unittest.mock.Mock()
452457
mock_response.choices = [unittest.mock.Mock()]
453458
mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
454459
mock_response.choices[0].message.tool_calls[0].function.arguments = '{"name": "John", "age": 30}'
455460

456461
mistral_client.chat.complete.return_value = mock_response
457462

458-
prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
459-
result = model.structured_output(TestModel, prompt)
460-
461-
assert isinstance(result, TestModel)
462-
assert result.name == "John"
463-
assert result.age == 30
464-
463+
stream = model.structured_output(test_output_model_cls, messages)
465464

466-
def test_structured_output_no_tool_calls(mistral_client, model):
467-
from pydantic import BaseModel
465+
tru_result = list(stream)[-1]
466+
exp_result = {"output": test_output_model_cls(name="John", age=30)}
467+
assert tru_result == exp_result
468468

469-
class TestModel(BaseModel):
470-
name: str
471469

470+
def test_structured_output_no_tool_calls(mistral_client, model, test_output_model_cls):
472471
mock_response = unittest.mock.Mock()
473472
mock_response.choices = [unittest.mock.Mock()]
474473
mock_response.choices[0].message.tool_calls = None
@@ -478,15 +477,11 @@ class TestModel(BaseModel):
478477
prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
479478

480479
with pytest.raises(ValueError, match="No tool calls found in response"):
481-
model.structured_output(TestModel, prompt)
482-
480+
stream = model.structured_output(test_output_model_cls, prompt)
481+
next(stream)
483482

484-
def test_structured_output_invalid_json(mistral_client, model):
485-
from pydantic import BaseModel
486-
487-
class TestModel(BaseModel):
488-
name: str
489483

484+
def test_structured_output_invalid_json(mistral_client, model, test_output_model_cls):
490485
mock_response = unittest.mock.Mock()
491486
mock_response.choices = [unittest.mock.Mock()]
492487
mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
@@ -497,4 +492,5 @@ class TestModel(BaseModel):
497492
prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
498493

499494
with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"):
500-
model.structured_output(TestModel, prompt)
495+
stream = model.structured_output(test_output_model_cls, prompt)
496+
next(stream)

0 commit comments

Comments
 (0)