@@ -1,5 +1,6 @@
import unittest.mock

+ import pydantic
import pytest

import strands
@@ -58,6 +59,15 @@ def system_prompt():
    return "You are a helpful assistant"


+ @pytest.fixture
+ def test_output_model_cls():
+     class TestOutputModel(pydantic.BaseModel):
+         name: str
+         age: int
+
+     return TestOutputModel
+
+
def test__init__model_configs(mistral_client, model_id, max_tokens):
    _ = mistral_client

@@ -440,35 +450,25 @@ def test_stream_other_error(mistral_client, model):
        list(model.stream({}))


- def test_structured_output_success(mistral_client, model):
-     from pydantic import BaseModel
-
-     class TestModel(BaseModel):
-         name: str
-         age: int
+ def test_structured_output_success(mistral_client, model, test_output_model_cls):
+     messages = [{"role": "user", "content": [{"text": "Extract data"}]}]

-     # Mock successful response
    mock_response = unittest.mock.Mock()
    mock_response.choices = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls[0].function.arguments = '{"name": "John", "age": 30}'

    mistral_client.chat.complete.return_value = mock_response

-     prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]
-     result = model.structured_output(TestModel, prompt)
-
-     assert isinstance(result, TestModel)
-     assert result.name == "John"
-     assert result.age == 30
-
+     stream = model.structured_output(test_output_model_cls, messages)

- def test_structured_output_no_tool_calls(mistral_client, model):
-     from pydantic import BaseModel
+     tru_result = list(stream)[-1]
+     exp_result = {"output": test_output_model_cls(name="John", age=30)}
+     assert tru_result == exp_result

-     class TestModel(BaseModel):
-         name: str

+ def test_structured_output_no_tool_calls(mistral_client, model, test_output_model_cls):
    mock_response = unittest.mock.Mock()
    mock_response.choices = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls = None
@@ -478,15 +478,11 @@ class TestModel(BaseModel):
    prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]

    with pytest.raises(ValueError, match="No tool calls found in response"):
-         model.structured_output(TestModel, prompt)
-
+         stream = model.structured_output(test_output_model_cls, prompt)
+         next(stream)

- def test_structured_output_invalid_json(mistral_client, model):
-     from pydantic import BaseModel
-
-     class TestModel(BaseModel):
-         name: str

+ def test_structured_output_invalid_json(mistral_client, model, test_output_model_cls):
    mock_response = unittest.mock.Mock()
    mock_response.choices = [unittest.mock.Mock()]
    mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()]
@@ -497,4 +493,5 @@ class TestModel(BaseModel):
    prompt = [{"role": "user", "content": [{"text": "Extract data"}]}]

    with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"):
-         model.structured_output(TestModel, prompt)
+         stream = model.structured_output(test_output_model_cls, prompt)
+         next(stream)
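
Note: the updated tests treat structured_output as a generator rather than a plain return value; the parsed Pydantic instance arrives in the final yielded event under the "output" key. Below is a minimal sketch of that consumption pattern using a stand-in generator (fake_structured_output and its single-event stream are illustrative assumptions, not the real MistralModel implementation):

from typing import Any, Generator

import pydantic


class Person(pydantic.BaseModel):
    name: str
    age: int


def fake_structured_output(output_model, messages) -> Generator[dict[str, Any], None, None]:
    # Stand-in for model.structured_output(...): yields events and ends with the parsed model.
    _ = messages
    yield {"output": output_model(name="John", age=30)}


messages = [{"role": "user", "content": [{"text": "Extract data"}]}]
stream = fake_structured_output(Person, messages)

# The tests read the result the same way: take the last event and unwrap "output".
result = list(stream)[-1]["output"]
assert result == Person(name="John", age=30)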