Skip to content

Add Cohere API agent #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ The following table shows the required environment variables to set depending on

The following table shows the required environment variables to set depending on the agent type:

| Agent Type | ANTHROPIC_API_KEY | HF_TOKEN | OPENAI_API_KEY | OPENAI_BASE_URL | GOOGLE_CLOUD_PROJECT or CLOUD_ML_PROJECT_ID | GOOGLE_CLOUD_REGION or CLOUD_ML_REGION |
| --------------------- | :---------------: | :------: | :------------: | :-------------: | :-----------------------------------------: | :------------------------------------: |
| Claude_3_Haiku | ✅ | | | | | |
| Claude_3_Opus | ✅ | | | | | |
| Claude_3_Sonnet | ✅ | | | | | |
| Cli | | | | | | |
| Cohere_Command_R | | | | ✅ | | |
| Cohere_Command_R_Plus | | | | ✅ | | |
| GPT_3_5_0125 | | | ✅ | | | |
| GPT_4_0125 | | | ✅ | | | |
| GPT_4_o_2024_05_13 | | | ✅ | | | |
| Gemini_1_0 | | | | | ✅ | ✅ |
| Gemini_1_5 | | | | | ✅ | ✅ |
| Gemini_1_5_Flash | | | | | ✅ | ✅ |
| Gorilla | | | | ✅ | | |
| Hermes | | | | ✅ | | |
| Mistral | | | | ✅ | | |
| Agent Type | ANTHROPIC_API_KEY | CO_API_KEY | HF_TOKEN | OPENAI_API_KEY | OPENAI_BASE_URL | GOOGLE_CLOUD_PROJECT or CLOUD_ML_PROJECT_ID | GOOGLE_CLOUD_REGION or CLOUD_ML_REGION |
| --------------------- | :---------------: |------------| :------: | :------------: |:---------------:| :-----------------------------------------: | :------------------------------------: |
| Claude_3_Haiku | ✅ | | | | | | |
| Claude_3_Opus | ✅ | | | | | | |
| Claude_3_Sonnet | ✅ | | | | | | |
| Cli | | | | | | | |
| Cohere_Command_R | | | | | | | |
| Cohere_Command_R_Plus | | | | | | | |
| GPT_3_5_0125 | | | | ✅ | | | |
| GPT_4_0125 | | | | ✅ | | | |
| GPT_4_o_2024_05_13 | | | | ✅ | | | |
| Gemini_1_0 | | | | | | ✅ | ✅ |
| Gemini_1_5 | | | | | | ✅ | ✅ |
| Gemini_1_5_Flash | | | | | | ✅ | ✅ |
| Gorilla | | | | | ✅ | | |
| Hermes | | | | | ✅ | | |
| Mistral | | | | | ✅ | | |

The search tools in the ToolSandbox use [RapidAPI](https://rapidapi.com/hub) so in order to run those scenarios you need to have an API key and expose it in an environment variable called `RAPID_API_KEY`. Using models from the Gemini family requires setting up google authentication, e.g. by using [application default credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc).

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
"tree-sitter-languages==1.10.2",
"typing_extensions==4.12.2",
"vertexai==1.49.0",
"cohere==5.8.1",
]

[project.urls]
Expand Down
11 changes: 3 additions & 8 deletions tests/roles/cohere_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Unit tests for the Cohere agent."""

from typing import cast

import pytest
from openai.types.chat import ChatCompletionToolParam

from tool_sandbox.common.tool_conversion import convert_to_openai_tool
from tool_sandbox.common.tool_discovery import ToolBackend, get_all_tools
from tool_sandbox.roles.cohere_agent import to_cohere_tool
from tool_sandbox.roles.cohere_api_agent import to_cohere_tool


@pytest.mark.parametrize("tool_backend", ToolBackend)
def test_tool_conversion(tool_backend: ToolBackend) -> None:
"""Ensure that all our tools can be converted to the Cohere format."""
name_to_tool = get_all_tools(preferred_tool_backend=tool_backend)
for tool in name_to_tool.values():
openai_tool = convert_to_openai_tool(tool)
assert to_cohere_tool(cast(ChatCompletionToolParam, openai_tool)) is not None
for name, tool in name_to_tool.items():
assert to_cohere_tool(name=name, tool=tool) is not None
10 changes: 4 additions & 6 deletions tool_sandbox/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from tool_sandbox.roles.base_role import BaseRole
from tool_sandbox.roles.cli_role import CliAgent, CliUser
from tool_sandbox.roles.cohere_agent import CohereAgent
from tool_sandbox.roles.cohere_api_agent import CohereAPIAgent
from tool_sandbox.roles.execution_environment import ExecutionEnvironment
from tool_sandbox.roles.gemini_agent import GeminiAgent
from tool_sandbox.roles.gorilla_api_agent import GorillaAPIAgent
Expand Down Expand Up @@ -84,11 +84,9 @@ class RoleImplType(StrEnum):
model_name="gemini-1.5-flash-001"
),
RoleImplType.Cli: CliAgent,
RoleImplType.Cohere_Command_R: lambda: CohereAgent(
model_name="CohereForAI/c4ai-command-r-v01"
),
RoleImplType.Cohere_Command_R_Plus: lambda: CohereAgent(
model_name="CohereForAI/c4ai-command-r-plus"
RoleImplType.Cohere_Command_R: lambda: CohereAPIAgent(model_name="command-r"),
RoleImplType.Cohere_Command_R_Plus: lambda: CohereAPIAgent(
model_name="command-r-plus"
),
RoleImplType.Unhelpful: UnhelpfulAgent,
}
Expand Down
Loading