@@ -90,11 +90,9 @@ def __init__(
90
90
91
91
logger .debug ("config=<%s> | initializing" , self .config )
92
92
93
- client_args = client_args or {}
93
+ self . client_args = client_args or {}
94
94
if api_key :
95
- client_args ["api_key" ] = api_key
96
-
97
- self .client = mistralai .Mistral (** client_args )
95
+ self .client_args ["api_key" ] = api_key
98
96
99
97
@override
100
98
def update_config (self , ** model_config : Unpack [MistralConfig ]) -> None : # type: ignore
@@ -421,67 +419,70 @@ async def stream(
421
419
logger .debug ("got response from model" )
422
420
if not self .config .get ("stream" , True ):
423
421
# Use non-streaming API
424
- response = await self .client .chat .complete_async (** request )
425
- for event in self ._handle_non_streaming_response (response ):
426
- yield self .format_chunk (event )
422
+ async with mistralai .Mistral (** self .client_args ) as client :
423
+ response = await client .chat .complete_async (** request )
424
+ for event in self ._handle_non_streaming_response (response ):
425
+ yield self .format_chunk (event )
426
+
427
427
return
428
428
429
429
# Use the streaming API
430
- stream_response = await self .client .chat .stream_async (** request )
430
+ async with mistralai .Mistral (** self .client_args ) as client :
431
+ stream_response = await client .chat .stream_async (** request )
431
432
432
- yield self .format_chunk ({"chunk_type" : "message_start" })
433
+ yield self .format_chunk ({"chunk_type" : "message_start" })
433
434
434
- content_started = False
435
- tool_calls : dict [str , list [Any ]] = {}
436
- accumulated_text = ""
435
+ content_started = False
436
+ tool_calls : dict [str , list [Any ]] = {}
437
+ accumulated_text = ""
437
438
438
- async for chunk in stream_response :
439
- if hasattr (chunk , "data" ) and hasattr (chunk .data , "choices" ) and chunk .data .choices :
440
- choice = chunk .data .choices [0 ]
439
+ async for chunk in stream_response :
440
+ if hasattr (chunk , "data" ) and hasattr (chunk .data , "choices" ) and chunk .data .choices :
441
+ choice = chunk .data .choices [0 ]
441
442
442
- if hasattr (choice , "delta" ):
443
- delta = choice .delta
443
+ if hasattr (choice , "delta" ):
444
+ delta = choice .delta
444
445
445
- if hasattr (delta , "content" ) and delta .content :
446
- if not content_started :
447
- yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
448
- content_started = True
446
+ if hasattr (delta , "content" ) and delta .content :
447
+ if not content_started :
448
+ yield self .format_chunk ({"chunk_type" : "content_start" , "data_type" : "text" })
449
+ content_started = True
449
450
450
- yield self .format_chunk (
451
- {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : delta .content }
452
- )
453
- accumulated_text += delta .content
451
+ yield self .format_chunk (
452
+ {"chunk_type" : "content_delta" , "data_type" : "text" , "data" : delta .content }
453
+ )
454
+ accumulated_text += delta .content
454
455
455
- if hasattr (delta , "tool_calls" ) and delta .tool_calls :
456
- for tool_call in delta .tool_calls :
457
- tool_id = tool_call .id
458
- tool_calls .setdefault (tool_id , []).append (tool_call )
456
+ if hasattr (delta , "tool_calls" ) and delta .tool_calls :
457
+ for tool_call in delta .tool_calls :
458
+ tool_id = tool_call .id
459
+ tool_calls .setdefault (tool_id , []).append (tool_call )
459
460
460
- if hasattr (choice , "finish_reason" ) and choice .finish_reason :
461
- if content_started :
462
- yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "text" })
461
+ if hasattr (choice , "finish_reason" ) and choice .finish_reason :
462
+ if content_started :
463
+ yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "text" })
463
464
464
- for tool_deltas in tool_calls .values ():
465
- yield self .format_chunk (
466
- {"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]}
467
- )
465
+ for tool_deltas in tool_calls .values ():
466
+ yield self .format_chunk (
467
+ {"chunk_type" : "content_start" , "data_type" : "tool" , "data" : tool_deltas [0 ]}
468
+ )
468
469
469
- for tool_delta in tool_deltas :
470
- if hasattr (tool_delta .function , "arguments" ):
471
- yield self .format_chunk (
472
- {
473
- "chunk_type" : "content_delta" ,
474
- "data_type" : "tool" ,
475
- "data" : tool_delta .function .arguments ,
476
- }
477
- )
470
+ for tool_delta in tool_deltas :
471
+ if hasattr (tool_delta .function , "arguments" ):
472
+ yield self .format_chunk (
473
+ {
474
+ "chunk_type" : "content_delta" ,
475
+ "data_type" : "tool" ,
476
+ "data" : tool_delta .function .arguments ,
477
+ }
478
+ )
478
479
479
- yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
480
+ yield self .format_chunk ({"chunk_type" : "content_stop" , "data_type" : "tool" })
480
481
481
- yield self .format_chunk ({"chunk_type" : "message_stop" , "data" : choice .finish_reason })
482
+ yield self .format_chunk ({"chunk_type" : "message_stop" , "data" : choice .finish_reason })
482
483
483
- if hasattr (chunk , "usage" ):
484
- yield self .format_chunk ({"chunk_type" : "metadata" , "data" : chunk .usage })
484
+ if hasattr (chunk , "usage" ):
485
+ yield self .format_chunk ({"chunk_type" : "metadata" , "data" : chunk .usage })
485
486
486
487
except Exception as e :
487
488
if "rate" in str (e ).lower () or "429" in str (e ):
@@ -518,7 +519,8 @@ async def structured_output(
518
519
formatted_request ["tool_choice" ] = "any"
519
520
formatted_request ["parallel_tool_calls" ] = False
520
521
521
- response = await self .client .chat .complete_async (** formatted_request )
522
+ async with mistralai .Mistral (** self .client_args ) as client :
523
+ response = await client .chat .complete_async (** formatted_request )
522
524
523
525
if response .choices and response .choices [0 ].message .tool_calls :
524
526
tool_call = response .choices [0 ].message .tool_calls [0 ]
0 commit comments