Commit: Add openai tool support

jrmi committed Dec 10, 2024
1 parent 1b647e2 commit a1cadba
Showing 6 changed files with 329 additions and 111 deletions.
2 changes: 1 addition & 1 deletion gptme/constants.py
@@ -19,7 +19,7 @@
"user": "green",
"assistant": "green",
"system": "grey42",
"tool_result": "red",
"tool_result": "grey42",
}

# colors wrapped in \001 and \002 to inform readline about non-printable characters
21 changes: 10 additions & 11 deletions gptme/llm/llm_anthropic.py
Expand Up @@ -146,7 +146,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
}
]
yield modified_message
# Find tool_use occurrence and format them as expected
# Find tool_use occurrences and format them as expected
elif message["role"] == "assistant":
modified_message = dict(message)
text = ""
@@ -206,10 +206,6 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
yield message


def _handle_files(message_dicts: list[dict]) -> list[dict]:
return [_process_file(message_dict) for message_dict in message_dicts]


def _process_file(message_dict: dict) -> dict:
message_content = message_dict["content"]

@@ -325,7 +321,7 @@ def _transform_system_messages(
return messages, system_messages


def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
def _parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
required = []
properties = {}

@@ -353,7 +349,7 @@ def _spec2tool(
return {
"name": name,
"description": spec.get_instructions("tool"),
"input_schema": parameters2dict(spec.parameters),
"input_schema": _parameters2dict(spec.parameters),
}
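For context, the `input_schema` these helpers build is ordinary JSON Schema. A minimal standalone sketch of what `_parameters2dict` plausibly produces, assuming `Parameter` carries name, type, description, and required fields (the real gptme class may differ):

from dataclasses import dataclass

@dataclass
class Parameter:
    # Assumed fields; illustrative only.
    name: str
    type: str
    description: str
    required: bool

def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
    # Build a JSON Schema object: one property per parameter,
    # collecting the names of required parameters separately.
    properties = {
        p.name: {"type": p.type, "description": p.description}
        for p in parameters
    }
    required = [p.name for p in parameters if p.required]
    return {"type": "object", "properties": properties, "required": required}

# Example: a tool with one required parameter yields an input_schema like
# {'type': 'object',
#  'properties': {'path': {'type': 'string', 'description': 'File to read'}},
#  'required': ['path']}
print(parameters2dict([Parameter("path", "string", "File to read", True)]))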


@@ -395,7 +391,13 @@ def _prepare_messages_for_api(
messages, system_messages = _transform_system_messages(messages)

# Handle files and convert to dicts
messages_dicts = _handle_files(msgs2dicts(messages))
messages_dicts = (_process_file(f) for f in msgs2dicts(messages))

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)

# Apply cache control to optimize performance
messages_dicts_new: list[PromptCachingBetaMessageParam] = []
@@ -429,7 +431,4 @@ def _prepare_messages_for_api(
assert isinstance(msgp["content"], list)
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"}

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

return messages_dicts_new, system_messages, tools_dict
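The broader refactor in this file swaps the eager `_handle_files` list helper for a generator pipeline that is materialized only once, after tool handling. A simplified sketch of that pattern, with stand-in stage functions (not the actual gptme code):

from collections.abc import Iterable, Iterator

def process_file(msg: dict) -> dict:
    return msg  # stand-in for per-message file processing

def handle_tools(msgs: Iterable[dict]) -> Iterator[dict]:
    for msg in msgs:
        yield msg  # stand-in for tool_use/tool_result formatting

def prepare(messages: list[dict], use_tools: bool) -> list[dict]:
    # Each stage wraps the previous one lazily; no intermediate
    # list is built until the final list() call.
    pipeline: Iterable[dict] = (process_file(m) for m in messages)
    if use_tools:
        pipeline = handle_tools(pipeline)
    return list(pipeline)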
163 changes: 107 additions & 56 deletions gptme/llm/llm_openai.py
@@ -10,7 +10,7 @@
from ..constants import TEMPERATURE, TOP_P
from ..message import Message, msgs2dicts
from ..tools.base import Parameter, ToolSpec, ToolUse
from .models import Provider, get_model
from .models import ModelMeta, Provider, get_model

if TYPE_CHECKING:
# noreorder
@@ -114,18 +114,11 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
# top_p controls diversity, temperature controls randomness
assert openai, "LLM not initialized"

is_o1 = model.startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))

messages_dicts: Iterable[dict] = _handle_files(msgs2dicts(messages))

from openai import NOT_GIVEN # fmt: skip

tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
is_o1 = model.startswith("o1")

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)
messages_dicts, tools_dict = _prepare_messages_for_api(messages, tools)

response = openai.chat.completions.create(
model=model,
@@ -148,18 +141,11 @@ def stream(
assert openai, "LLM not initialized"
stop_reason = None

is_o1 = model.startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))

messages_dicts: Iterable[dict] = _handle_files(msgs2dicts(messages))

from openai import NOT_GIVEN # fmt: skip

tools_dict = [_spec2tool(tool) for tool in tools] if tools else None
is_o1 = model.startswith("o1")

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)
messages_dicts, tools_dict = _prepare_messages_for_api(messages, tools)

for chunk_raw in openai.chat.completions.create(
model=model,
@@ -211,54 +197,97 @@ def stream(
logger.debug(f"Stop reason: {stop_reason}")


def _handle_files(msgs: list[dict]) -> list[dict]:
return [_process_file(msg) for msg in msgs]
def _handle_tool_results(message_dicts: Iterable[dict]) -> Iterable[dict]:
for message in message_dicts:
if message["role"] == "tool_result":
modified_message = dict(message)
modified_message["role"] = "system"
yield modified_message
else:
yield message
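In effect, when the tools API is not in use, a `tool_result` message is downgraded to a plain system message so the endpoint will accept it. A quick illustration (the message content is invented):

before = {"role": "tool_result", "content": "exit code 0"}
after = next(_handle_tool_results([before]))
assert after == {"role": "system", "content": "exit code 0"}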


def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool_result ass expected by the model
# Format tool_result as expected by the model
if message["role"] == "tool_result":
modified_message = dict(message)
modified_message["role"] = "tool"
modified_message["tool_call_id"] = modified_message.pop("call_id")
if "call_id" in modified_message:
modified_message["role"] = "tool"
modified_message["tool_call_id"] = modified_message.pop("call_id")
else:
# No associated tool call; treat it as a system message
modified_message["role"] = "system"
yield modified_message
# Find tool_use occurrence and format them as expected
# Find tool_use occurrences and format them as expected
elif message["role"] == "assistant":
modified_message = dict(message)

tooluses = [
tooluse
for tooluse in ToolUse.iter_from_content(modified_message["content"])
if tooluse.is_runnable
]
if not tooluses:
yield message

# At that point we should always have exactly one tooluse
# Because we remove the previous ones as soon as we encounter
# them so we can't have more.
assert len(tooluses) == 1
tooluse = tooluses[0]

del modified_message["content"]
modified_message["tool_calls"] = {
"id": tooluse.call_id or "",
"type": "function",
"function": {
"name": tooluse.tool,
"arguments": json.dumps(tooluse.kwargs or {}),
},
}
text = ""
content = []
tool_calls = []

# Content can be a plain string or a list of parts
if isinstance(message["content"], list):
message_parts = message["content"]
else:
message_parts = [{"type": "text", "text": message["content"]}]

for message_part in message_parts:
if message_part["type"] != "text":
content.append(message_part)
continue

# For a message part of type `text` we try to extract the tool uses.
# We scan line by line so we can stop as soon as we hit a tool call,
# which makes it easier to split the text into multiple parts.
for line in message_part["text"].split("\n"):
text += line + "\n"

tooluses = [
tooluse
for tooluse in ToolUse.iter_from_content(text)
if tooluse.is_runnable
]
if not tooluses:
continue

# At this point there should be exactly one tooluse, because the
# text buffer is reset as soon as a tool call is found, so earlier
# ones cannot still be present.
assert len(tooluses) == 1
tooluse = tooluses[0]
before_tool = text[: tooluse.start]

if before_tool:
content.append({"type": "text", "text": before_tool})

tool_calls.append(
{
"id": tooluse.call_id or "",
"type": "function",
"function": {
"name": tooluse.tool,
"arguments": json.dumps(tooluse.kwargs or {}),
},
}
)
# Reset the text buffer to start over with the remaining lines, if any.
text = ""

if content:
modified_message["content"] = content
else:
del modified_message["content"]
if tool_calls:
modified_message["tool_calls"] = tool_calls

yield modified_message
else:
yield message
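Concretely, an assistant message that embeds one runnable tool block should come out split into a text content part plus a `tool_calls` entry, roughly as below. The tool-block syntax, tool name, and call id are illustrative assumptions, not taken from this diff:

# Input: assistant message with a runnable block in its text.
assistant_in = {
    "role": "assistant",
    "content": "Let me list the files.\n```shell\nls -la\n```",
}

# Expected output shape from _handle_tools: the text before the tool
# call becomes a content part, and the call itself moves to tool_calls.
assistant_out = {
    "role": "assistant",
    "content": [{"type": "text", "text": "Let me list the files.\n"}],
    "tool_calls": [
        {
            "id": "call_1",  # tooluse.call_id, or "" when absent
            "type": "function",
            "function": {
                "name": "shell",
                "arguments": "{\"command\": \"ls -la\"}",
            },
        }
    ],
}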


def _process_file(msg: dict) -> dict:
def _process_file(msg: dict, model: ModelMeta) -> dict:
message_content = msg["content"]
model = get_model()
if model.provider == "deepseek":
# deepseek does not support files
return msg
@@ -327,7 +356,7 @@ def _process_file(msg: dict, model: ModelMeta) -> dict:
return msg


def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
def _parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
required = []
properties = {}

@@ -344,7 +373,7 @@ def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
}


def _spec2tool(spec: ToolSpec) -> "ChatCompletionToolParam":
def _spec2tool(spec: ToolSpec, model: ModelMeta) -> "ChatCompletionToolParam":
name = spec.name
if spec.block_types:
name = spec.block_types[0]
@@ -358,16 +387,38 @@ def _spec2tool(spec: ToolSpec) -> "ChatCompletionToolParam":
)
description = description[:1024]

provider = get_provider()
if provider in ["openai", "azure", "openrouter", "local"]:
if model.provider in ["openai", "azure", "openrouter", "local"]:
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters2dict(spec.parameters),
"parameters": _parameters2dict(spec.parameters),
# "strict": False, # not supported by OpenRouter
},
}
else:
raise ValueError("Provider doesn't support tools API")
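For the OpenAI-style providers, the returned dict follows the standard function-tool shape. An illustrative result for a hypothetical `save` tool — the name, description, and fields are invented, and `parameters` is assumed to be the JSON Schema emitted by `_parameters2dict`:

{
    "type": "function",
    "function": {
        "name": "save",
        "description": "Save content to a file.",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "Destination path"}
            },
            "required": ["path"],
        },
    },
}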


def _prepare_messages_for_api(
messages: list[Message], tools: list[ToolSpec] | None
) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
model = get_model()

is_o1 = model.model.startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))

messages_dicts: Iterable[dict] = (
_process_file(msg, model) for msg in msgs2dicts(messages)
)

tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)
else:
messages_dicts = _handle_tool_results(messages_dicts)

return list(messages_dicts), tools_dict
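With this helper in place, `chat()` and `stream()` reduce to message prep plus the API call, as in the refactored call sites above. A minimal usage sketch — the `tools=` argument and `NOT_GIVEN` fallback are assumptions mirroring the imports shown in those functions, not lines from this diff:

from openai import NOT_GIVEN

messages_dicts, tools_dict = _prepare_messages_for_api(messages, tools)
response = openai.chat.completions.create(
    model=model,
    messages=messages_dicts,        # files processed, tool calls formatted
    tools=tools_dict or NOT_GIVEN,  # omit the field when no tools are active
)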