1
1
import unittest .mock
2
2
3
+ import pydantic
3
4
import pytest
4
5
5
6
import strands
@@ -58,6 +59,15 @@ def system_prompt():
58
59
return "You are a helpful assistant"
59
60
60
61
62
+ @pytest .fixture
63
+ def test_output_model_cls ():
64
+ class TestOutputModel (pydantic .BaseModel ):
65
+ name : str
66
+ age : int
67
+
68
+ return TestOutputModel
69
+
70
+
61
71
def test__init__model_configs (mistral_client , model_id , max_tokens ):
62
72
_ = mistral_client
63
73
@@ -440,35 +450,24 @@ def test_stream_other_error(mistral_client, model):
440
450
list (model .stream ({}))
441
451
442
452
443
- def test_structured_output_success (mistral_client , model ):
444
- from pydantic import BaseModel
445
-
446
- class TestModel (BaseModel ):
447
- name : str
448
- age : int
453
+ def test_structured_output_success (mistral_client , model , test_output_model_cls ):
454
+ messages = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
449
455
450
- # Mock successful response
451
456
mock_response = unittest .mock .Mock ()
452
457
mock_response .choices = [unittest .mock .Mock ()]
453
458
mock_response .choices [0 ].message .tool_calls = [unittest .mock .Mock ()]
454
459
mock_response .choices [0 ].message .tool_calls [0 ].function .arguments = '{"name": "John", "age": 30}'
455
460
456
461
mistral_client .chat .complete .return_value = mock_response
457
462
458
- prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
459
- result = model .structured_output (TestModel , prompt )
460
-
461
- assert isinstance (result , TestModel )
462
- assert result .name == "John"
463
- assert result .age == 30
464
-
463
+ stream = model .structured_output (test_output_model_cls , messages )
465
464
466
- def test_structured_output_no_tool_calls (mistral_client , model ):
467
- from pydantic import BaseModel
465
+ tru_result = list (stream )[- 1 ]
466
+ exp_result = {"output" : test_output_model_cls (name = "John" , age = 30 )}
467
+ assert tru_result == exp_result
468
468
469
- class TestModel (BaseModel ):
470
- name : str
471
469
470
+ def test_structured_output_no_tool_calls (mistral_client , model , test_output_model_cls ):
472
471
mock_response = unittest .mock .Mock ()
473
472
mock_response .choices = [unittest .mock .Mock ()]
474
473
mock_response .choices [0 ].message .tool_calls = None
@@ -478,15 +477,11 @@ class TestModel(BaseModel):
478
477
prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
479
478
480
479
with pytest .raises (ValueError , match = "No tool calls found in response" ):
481
- model .structured_output (TestModel , prompt )
482
-
480
+ stream = model .structured_output (test_output_model_cls , prompt )
481
+ next ( stream )
483
482
484
- def test_structured_output_invalid_json (mistral_client , model ):
485
- from pydantic import BaseModel
486
-
487
- class TestModel (BaseModel ):
488
- name : str
489
483
484
+ def test_structured_output_invalid_json (mistral_client , model , test_output_model_cls ):
490
485
mock_response = unittest .mock .Mock ()
491
486
mock_response .choices = [unittest .mock .Mock ()]
492
487
mock_response .choices [0 ].message .tool_calls = [unittest .mock .Mock ()]
@@ -497,4 +492,5 @@ class TestModel(BaseModel):
497
492
prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
498
493
499
494
with pytest .raises (ValueError , match = "Failed to parse tool call arguments into model" ):
500
- model .structured_output (TestModel , prompt )
495
+ stream = model .structured_output (test_output_model_cls , prompt )
496
+ next (stream )
0 commit comments