Skip to content

Add cohere client #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ lint-fix = [
]

[tool.hatch.envs.hatch-test]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "cohere"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
Expand Down
5 changes: 3 additions & 2 deletions src/strands/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
This package includes an abstract base Model class along with concrete implementations for specific providers.
"""

from . import bedrock
from . import bedrock, cohere
from .bedrock import BedrockModel
from .cohere import CohereModel

__all__ = ["bedrock", "BedrockModel"]
__all__ = ["bedrock", "BedrockModel", "cohere", "CohereModel"]
34 changes: 34 additions & 0 deletions src/strands/models/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Cohere model provider using OpenAI compatibility API.

- Docs: https://docs.cohere.com/docs/compatibility-api
"""

import logging

from typing_extensions import Unpack

from .openai import OpenAIModel

logger = logging.getLogger(__name__)


class CohereModel(OpenAIModel):
    """Cohere model provider implemented on top of Cohere's OpenAI compatibility API.

    Cohere exposes an OpenAI-compatible endpoint, so this provider only needs to
    point the inherited OpenAI client at Cohere's compatibility base URL; all
    request/response handling is inherited from :class:`OpenAIModel`.
    """

    # Cohere's OpenAI-compatibility endpoint (see https://docs.cohere.com/docs/compatibility-api).
    # Exposed as a class attribute so subclasses/tests can override it without
    # changing __init__'s signature.
    COMPATIBILITY_BASE_URL = "https://api.cohere.com/compatibility/v1"

    class CohereConfig(OpenAIModel.OpenAIConfig):
        """Configuration options for Cohere models.

        Inherits every option from ``OpenAIModel.OpenAIConfig`` (e.g. ``model_id``,
        ``params``) unchanged.
        """

    def __init__(self, api_key: str, **model_config: Unpack[CohereConfig]):
        """Initialize Cohere provider instance.

        Args:
            api_key: Cohere API key used to authenticate requests.
            **model_config: Configuration options for the Cohere model.
        """
        client_args = {
            "base_url": self.COMPATIBILITY_BASE_URL,
            "api_key": api_key,
        }
        super().__init__(client_args=client_args, **model_config)
43 changes: 43 additions & 0 deletions tests-integ/test_model_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

import pytest

import strands
from strands import Agent
from strands.models.cohere import CohereModel


@pytest.fixture
def model():
    """Real CohereModel pointed at the API key from the environment."""
    api_key = os.getenv("CO_API_KEY")
    return CohereModel(
        model_id="command-a-03-2025",
        api_key=api_key,
    )


@pytest.fixture
def tools():
    """Two stub tools with fixed answers so the assertions below are deterministic."""
    # NOTE(review): no docstrings on the tool functions — @strands.tool presumably
    # derives the tool description from the docstring, so adding one could change
    # the spec sent to the model. Verify against the decorator before documenting.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]


@pytest.fixture
def agent(model, tools):
    """Agent under test, wired with the Cohere model and the stub tools."""
    return Agent(model=model, tools=tools)


@pytest.mark.skipif(
    "CO_API_KEY" not in os.environ,
    reason="CO_API_KEY environment variable missing",
)
def test_agent(agent):
    """End-to-end: the agent should invoke both tools and surface their outputs."""
    result = agent("What is the time and weather in New York?")
    text = result.message["content"][0]["text"].lower()
    for expected in ("12:00", "sunny"):
        assert expected in text
143 changes: 143 additions & 0 deletions tests/strands/models/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import unittest.mock

import pytest

import strands
from strands.models.cohere import CohereModel


@pytest.fixture
def cohere_client_cls():
    """Patch the openai module used by the OpenAI-compat base; yield the mocked client class."""
    patcher = unittest.mock.patch.object(strands.models.openai, "openai")
    mocked_openai_mod = patcher.start()
    try:
        yield mocked_openai_mod.OpenAI
    finally:
        patcher.stop()


@pytest.fixture
def cohere_client(cohere_client_cls):
    """Mock client instance produced when the patched OpenAI class is called."""
    client_instance = cohere_client_cls.return_value
    return client_instance


@pytest.fixture
def model_id():
    """Arbitrary model identifier shared by the tests below."""
    return "m1"


@pytest.fixture
def model(cohere_client, model_id):
    """CohereModel built while the OpenAI client class is patched."""
    # cohere_client is requested only for its patching side effect.
    del cohere_client
    return CohereModel(api_key="k1", model_id=model_id)


@pytest.fixture
def messages():
    """Minimal single-turn user message in the provider's message format."""
    return [{"role": "user", "content": [{"text": "test"}]}]


@pytest.fixture
def system_prompt():
    """Arbitrary system prompt value."""
    return "s1"


def test__init__(cohere_client_cls, model_id):
    """Constructor forwards config and points the client at Cohere's compatibility API.

    Fix: the asserted base_url previously used the legacy ``api.cohere.ai`` domain,
    while CohereModel hard-codes ``api.cohere.com`` — the test could never pass.
    The assertion must match the URL in CohereModel.__init__ exactly.
    """
    model = CohereModel(api_key="k1", model_id=model_id, params={"max_tokens": 1})

    tru_config = model.get_config()
    exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}

    assert tru_config == exp_config
    cohere_client_cls.assert_called_once_with(base_url="https://api.cohere.com/compatibility/v1", api_key="k1")


def test_update_config(model, model_id):
    """update_config stores the supplied model_id verbatim."""
    model.update_config(model_id=model_id)
    assert model.get_config().get("model_id") == model_id


def test_stream(cohere_client, model):
    """Interleaved text and two tool calls yield the expected ordered chunk sequence."""

    def chunk(finish_reason, delta):
        # Wrap a delta in the single-choice event shape the streaming client yields.
        return unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=finish_reason, delta=delta)])

    tool_1_first = unittest.mock.Mock(index=0)
    tool_2_first = unittest.mock.Mock(index=1)
    tool_1_second = unittest.mock.Mock(index=0)
    tool_2_second = unittest.mock.Mock(index=1)

    usage_event = unittest.mock.Mock()
    stream_events = [
        chunk(None, unittest.mock.Mock(content="I'll calculate", tool_calls=[tool_1_first, tool_2_first])),
        chunk(None, unittest.mock.Mock(content="that for you", tool_calls=[tool_1_second, tool_2_second])),
        chunk("tool_calls", unittest.mock.Mock(content="", tool_calls=None)),
        usage_event,
    ]
    cohere_client.chat.completions.create.return_value = iter(stream_events)

    request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
    received = list(model.stream(request))

    assert received == [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "content_start", "data_type": "tool", "data": tool_1_first},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_1_first},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_1_second},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "content_start", "data_type": "tool", "data": tool_2_first},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_2_first},
        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_2_second},
        {"chunk_type": "content_stop", "data_type": "tool"},
        {"chunk_type": "message_stop", "data": "tool_calls"},
        {"chunk_type": "metadata", "data": usage_event.usage},
    ]
    cohere_client.chat.completions.create.assert_called_once_with(**request)


def test_stream_empty(cohere_client, model):
    """A stream with no content and no tool calls still emits start/stop and metadata."""
    empty_delta = unittest.mock.Mock(content=None, tool_calls=None)
    usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)

    stream_events = [
        unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=empty_delta)]),
        unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=empty_delta)]),
        unittest.mock.Mock(),
        unittest.mock.Mock(usage=usage),
    ]
    cohere_client.chat.completions.create.return_value = iter(stream_events)

    request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
    received = list(model.stream(request))

    assert received == [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "message_stop", "data": "stop"},
        {"chunk_type": "metadata", "data": usage},
    ]
    cohere_client.chat.completions.create.assert_called_once_with(**request)


def test_stream_with_empty_choices(cohere_client, model):
    """Events with missing or empty ``choices`` are skipped without breaking the stream."""
    text_delta = unittest.mock.Mock(content="content", tool_calls=None)
    usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

    stream_events = [
        unittest.mock.Mock(spec=[]),  # no attributes at all: not even a choices field
        unittest.mock.Mock(choices=[]),  # present but empty choices list
        unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=text_delta)]),
        unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=text_delta)]),
        unittest.mock.Mock(usage=usage),
    ]
    cohere_client.chat.completions.create.return_value = iter(stream_events)

    request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
    received = list(model.stream(request))

    assert received == [
        {"chunk_type": "message_start"},
        {"chunk_type": "content_start", "data_type": "text"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
        {"chunk_type": "content_stop", "data_type": "text"},
        {"chunk_type": "message_stop", "data": "stop"},
        {"chunk_type": "metadata", "data": usage},
    ]
    cohere_client.chat.completions.create.assert_called_once_with(**request)