Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Mistral format tools #5473

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/llamafactory/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def create_chat_completion_response(
if isinstance(result, list):
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
function = Function(name=tool.name, arguments=tool.arguments)
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
Expand Down
26 changes: 8 additions & 18 deletions src/llamafactory/data/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Union

from typing_extensions import override

from .data_utils import SLOTS
from .tool_utils import get_tool_utils


if TYPE_CHECKING:
from .tool_utils import FunctionCall
from .tool_utils import FunctionCall, get_tool_utils


@dataclass
Expand Down Expand Up @@ -98,35 +94,29 @@ def apply(self, **kwargs) -> SLOTS:
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
self.tool_utils = get_tool_utils(self.tool_format)

@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]

for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)

except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string

elements = []
for slot in self.slots:
if slot == "{{content}}":
for name, arguments in functions:
for slot in self.function_slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)

Expand Down
23 changes: 18 additions & 5 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..hparams import DataArguments
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
from .tool_utils import FunctionCall


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -83,7 +84,7 @@ def encode_multiturn(
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
Expand Down Expand Up @@ -244,7 +245,7 @@ def _register_template(
)
```
"""
template_class = Llama2Template if name.startswith("llama2") else Template
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots)
Expand Down Expand Up @@ -854,7 +855,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
# copied from mistral template
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
Expand Down Expand Up @@ -902,7 +907,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
# copied from mistral template
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
Expand Down Expand Up @@ -939,7 +948,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:

_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)

Expand Down
106 changes: 78 additions & 28 deletions src/llamafactory/data/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Tuple, Union

from typing_extensions import override

from .data_utils import SLOTS


FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
# Lightweight record of one parsed tool invocation: the function name and
# its JSON-encoded argument string.
FunctionCall = NamedTuple("FunctionCall", [("name", str), ("arguments", str)])


DEFAULT_TOOL_PROMPT = (
Expand All @@ -38,13 +39,11 @@
"```\n"
)


GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)


LLAMA3_TOOL_PROMPT = (
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
Expand All @@ -61,35 +60,30 @@ class ToolUtils(ABC):

@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS:
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Gets a list of slots corresponding to a single function call.
Generates the system message describing all the available tools.
"""
...

@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
r"""
Generates the system message describing all the available tools.
Generates the assistant message including all the tool calls.
"""
...

@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
Extracts all the function calls from the assistant message.
"""
...


class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]

@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
Expand Down Expand Up @@ -124,6 +118,15 @@ def tool_formatter(tools: List[Dict[str, Any]]) -> str:

return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Renders tool calls as ReAct-style ``Action:`` / ``Action Input:`` stanzas,
    all concatenated into a single slot.
    """
    stanzas = [f"Action: {name}\nAction Input: {arguments}\n" for name, arguments in functions]
    return ["".join(stanzas)]

@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
Expand All @@ -138,19 +141,14 @@ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content

return results


class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]

@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
Expand All @@ -162,6 +160,14 @@ def tool_formatter(tools: List[Dict[str, Any]]) -> str:

return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Renders a single GLM-4 tool call as ``name\narguments`` in one slot.
    """
    # GLM-4 emits at most one call per assistant turn.
    if len(functions) > 1:
        raise ValueError("GLM-4 does not support parallel functions.")

    name, arguments = functions[0]
    return [f"{name}\n{arguments}"]

@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
Expand All @@ -174,7 +180,7 @@ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
except json.JSONDecodeError:
return content

return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]


class Llama3ToolUtils(ToolUtils):
Expand All @@ -184,11 +190,6 @@ class Llama3ToolUtils(ToolUtils):
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""

@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""]

@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
Expand All @@ -200,6 +201,14 @@ def tool_formatter(tools: List[Dict[str, Any]]) -> str:

return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)

@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
    r"""
    Renders a single Llama-3 tool call as a ``{"name": ..., "parameters": ...}``
    JSON object in one slot.
    """
    # Llama 3 JSON tool calling handles one function per turn.
    if len(functions) > 1:
        raise ValueError("Llama 3 does not support parallel functions.")

    name, parameters = functions[0]
    return ['{"name": "' + name + '", "parameters": ' + parameters + "}"]

@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
Expand All @@ -211,13 +220,54 @@ def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "name" not in tool or "parameters" not in tool:
return content

return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]


class MistralToolUtils(ToolUtils):
    r"""
    Tool helpers for the Mistral chat format: an ``[AVAILABLE_TOOLS]`` system
    block and a JSON list of ``{"name": ..., "arguments": ...}`` tool calls.
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        # Mistral expects each tool schema wrapped as {"type": "function", ...}.
        wrapped = [{"type": "function", "function": tool} for tool in tools]
        return "[AVAILABLE_TOOLS] " + json.dumps(wrapped, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        # Join the calls into one JSON-array-shaped string in a single slot.
        body = ", ".join(
            '{"name": "' + name + '", "arguments": ' + arguments + "}" for name, arguments in functions
        )
        return ["[" + body + "]"]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        try:
            parsed = json.loads(content.strip())
        except json.JSONDecodeError:
            return content  # not JSON at all: treat as a plain text reply

        if not isinstance(parsed, list):
            parsed = [parsed]  # normalize a single call to list form

        calls: List["FunctionCall"] = []
        for item in parsed:
            if "name" not in item or "arguments" not in item:
                return content  # malformed entry: fall back to raw text

            calls.append(FunctionCall(item["name"], json.dumps(item["arguments"], ensure_ascii=False)))

        return calls


# Registry mapping a tool_format name to its shared ToolUtils instance.
TOOLS = dict(
    default=DefaultToolUtils(),
    glm4=GLM4ToolUtils(),
    llama3=Llama3ToolUtils(),
    mistral=MistralToolUtils(),
)


Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/webui/chatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def stream(
result = response

if isinstance(result, list):
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
Expand Down
Loading
Loading