|
1 | 1 | import json
|
2 | 2 | import re
|
| 3 | +from random import choices |
| 4 | +from string import ascii_letters, digits |
3 | 5 | from typing import Dict, List, Sequence, Union
|
4 | 6 |
|
5 | 7 | import partial_json_parser
|
6 | 8 | from partial_json_parser.core.options import Allow
|
| 9 | +from pydantic import Field |
7 | 10 |
|
8 | 11 | from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
9 | 12 | DeltaToolCall,
|
|
19 | 22 |
|
20 | 23 | logger = init_logger(__name__)
|
21 | 24 |
|
| 25 | +ALPHANUMERIC = ascii_letters + digits |
| 26 | + |
| 27 | + |
| 28 | +class MistralToolCall(ToolCall): |
| 29 | + id: str = Field( |
| 30 | + default_factory=lambda: MistralToolCall.generate_random_id()) |
| 31 | + |
| 32 | + @staticmethod |
| 33 | + def generate_random_id(): |
| 34 | + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. |
| 35 | + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 |
| 36 | + return "".join(choices(ALPHANUMERIC, k=9)) |
| 37 | + |
22 | 38 |
|
23 | 39 | class MistralToolParser(ToolParser):
|
24 | 40 | """
|
@@ -71,8 +87,8 @@ def extract_tool_calls(self,
|
71 | 87 | # load the JSON, and then use it to build the Function and
|
72 | 88 | # Tool Call
|
73 | 89 | function_call_arr = json.loads(raw_tool_call)
|
74 |
| - tool_calls: List[ToolCall] = [ |
75 |
| - ToolCall( |
| 90 | + tool_calls: List[MistralToolCall] = [ |
| 91 | + MistralToolCall( |
76 | 92 | type="function",
|
77 | 93 | function=FunctionCall(
|
78 | 94 | name=raw_function_call["name"],
|
|
0 commit comments