Skip to content

Commit

Permalink
feat: add dedicated Tool Call model from Mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
gcalmettes committed Oct 2, 2024
1 parent 563649a commit 9c38dbe
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
from random import choices
from string import ascii_letters, digits

import torch
from openai.types.chat import ChatCompletionContentPartParam
Expand All @@ -21,6 +23,8 @@
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

ALPHANUMERIC = ascii_letters + digits

try:
from sphinx.ext.autodoc.mock import _MockModule

Expand Down Expand Up @@ -772,6 +776,17 @@ class ToolCall(OpenAIBaseModel):
function: FunctionCall


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class DeltaFunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
Expand All @@ -790,7 +805,7 @@ class ExtractedToolCallInformation(BaseModel):
tools_called: bool

# extracted tool calls
tool_calls: List[ToolCall]
tool_calls: List[ToolCall | MistralToolCall]

# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
Expand All @@ -800,7 +815,7 @@ class ExtractedToolCallInformation(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
tool_calls: List[ToolCall | MistralToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
Expand Down

0 comments on commit 9c38dbe

Please sign in to comment.