Skip to content

Commit

Permalink
feat: improve tool response
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmi committed Dec 13, 2024
1 parent a61aa28 commit 43785e5
Show file tree
Hide file tree
Showing 9 changed files with 618 additions and 75 deletions.
14 changes: 8 additions & 6 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
@click.option(
"--tool-format",
"tool_format",
default="markdown",
default=None,
help="Tool parsing method. Can be 'markdown', 'xml', 'tool'. (experimental)",
)
@click.option(
Expand Down Expand Up @@ -149,7 +149,7 @@ def main(
name: str,
model: str | None,
tool_allowlist: list[str] | None,
tool_format: ToolFormat,
tool_format: ToolFormat | None,
stream: bool,
verbose: bool,
no_confirm: bool,
Expand Down Expand Up @@ -187,8 +187,10 @@ def main(

config = get_config()

tool_format = tool_format or config.get_env("TOOL_FORMAT") or "markdown"
set_tool_format(tool_format)
selected_tool_format: ToolFormat = (
tool_format or config.get_env("TOOL_FORMAT") or "markdown" # type: ignore
)
set_tool_format(selected_tool_format)

# early init tools to generate system prompt
init_tools(frozenset(tool_allowlist) if tool_allowlist else None)
Expand All @@ -198,7 +200,7 @@ def main(
get_prompt(
prompt_system,
interactive=interactive,
tool_format=tool_format,
tool_format=selected_tool_format,
)
]

Expand Down Expand Up @@ -272,7 +274,7 @@ def main(
show_hidden,
workspace_path,
tool_allowlist,
tool_format,
selected_tool_format,
)
except RuntimeError as e:
logger.error(e)
Expand Down
110 changes: 96 additions & 14 deletions gptme/llm/llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
TypedDict,
cast,
)
from collections.abc import Iterable

from ..constants import TEMPERATURE, TOP_P
from ..message import Message, msgs2dicts
from ..tools.base import Parameter, ToolSpec
from ..tools.base import Parameter, ToolSpec, ToolUse

if TYPE_CHECKING:
# noreorder
Expand Down Expand Up @@ -66,9 +67,17 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
)
content = response.content
logger.debug(response.usage)
assert content
assert len(content) == 1
return content[0].text # type: ignore

parsed_block = []
for block in content:
if block.type == "text":
parsed_block.append(block.text)
elif block.type == "tool_use":
parsed_block.append(f"\n@{block.name}({block.id}): {block.input}")
else:
logger.warning("Unknown block: %s", str(block))

return "\n".join(parsed_block)


def stream(
Expand Down Expand Up @@ -99,7 +108,7 @@ def stream(
block = chunk.content_block
if isinstance(block, anthropic.types.ToolUseBlock):
tool_use = block
yield f"\n@{tool_use.name}: "
yield f"\n@{tool_use.name}({tool_use.id}): "
elif isinstance(block, anthropic.types.TextBlock):
if block.text:
logger.warning("unexpected text block: %s", block.text)
Expand Down Expand Up @@ -135,8 +144,78 @@ def stream(
pass


def _handle_files(message_dicts: list[dict]) -> list[dict]:
return [_process_file(message_dict) for message_dict in message_dicts]
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool result as expected by the model
if message["role"] == "system" and "call_id" in message:
modified_message = dict(message)
modified_message["role"] = "user"
modified_message["content"] = [
{
"type": "tool_result",
"content": modified_message["content"],
"tool_use_id": modified_message.pop("call_id"),
}
]
yield modified_message
# Find tool_use occurrences and format them as expected
elif message["role"] == "assistant":
modified_message = dict(message)
text = ""
content = []

# Some content are text, some are list
if isinstance(message["content"], list):
message_parts = message["content"]
else:
message_parts = [{"type": "text", "text": message["content"]}]

for message_part in message_parts:
if message_part["type"] != "text":
content.append(message_part)
continue

# For a message part of type `text`` we try to extract the tool_uses
# We search line by line to stop as soon as we have a tool call
# It makes it easier to split in multiple parts.
for line in message_part["text"].split("\n"):
text += line + "\n"

tooluses = [
tooluse
for tooluse in ToolUse.iter_from_content(text)
if tooluse.is_runnable
]
if not tooluses:
continue

# At that point we should always have exactly one tooluse
# Because we remove the previous ones as soon as we encounter
# them so we can't have more.
assert len(tooluses) == 1
tooluse = tooluses[0]
before_tool = text[: tooluse.start]

if before_tool:
content.append({"type": "text", "text": before_tool})

content.append(
{
"type": "tool_use",
"id": tooluse.call_id or "",
"name": tooluse.tool,
"input": tooluse.kwargs or {},
}
)
# The text is emptied to start over with the next lines if any.
text = ""

if content:
modified_message["content"] = content

yield modified_message
else:
yield message


def _process_file(message_dict: dict) -> dict:
Expand Down Expand Up @@ -224,7 +303,7 @@ def _transform_system_messages(

# for any subsequent system messages, transform them into a <system> message
for i, message in enumerate(messages):
if message.role == "system":
if message.role == "system" and message.call_id is None:
messages[i] = Message(
"user",
content=f"<system>{message.content}</system>",
Expand Down Expand Up @@ -254,7 +333,7 @@ def _transform_system_messages(
return messages, system_messages


def parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
def _parameters2dict(parameters: list[Parameter]) -> dict[str, object]:
required = []
properties = {}

Expand Down Expand Up @@ -282,7 +361,7 @@ def _spec2tool(
return {
"name": name,
"description": spec.get_instructions("tool"),
"input_schema": parameters2dict(spec.parameters),
"input_schema": _parameters2dict(spec.parameters),
}


Expand Down Expand Up @@ -324,7 +403,13 @@ def _prepare_messages_for_api(
messages, system_messages = _transform_system_messages(messages)

# Handle files and convert to dicts
messages_dicts = _handle_files(msgs2dicts(messages))
messages_dicts = (_process_file(f) for f in msgs2dicts(messages))

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)

# Apply cache control to optimize performance
messages_dicts_new: list[PromptCachingBetaMessageParam] = []
Expand Down Expand Up @@ -361,7 +446,4 @@ def _prepare_messages_for_api(
assert isinstance(msgp["content"], list)
msgp["content"][-1]["cache_control"] = {"type": "ephemeral"}

# Prepare tools
tools_dict = [_spec2tool(tool) for tool in tools] if tools else None

return messages_dicts_new, system_messages, tools_dict
Loading

0 comments on commit 43785e5

Please sign in to comment.