Skip to content

chore: Upgrade default model to command-r-08-2024 #1691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class CohereChatGenerator:
from haystack.utils import Secret
from haystack_integrations.components.generators.cohere import CohereChatGenerator

client = CohereChatGenerator(model="command-r", api_key=Secret.from_env_var("COHERE_API_KEY"))
client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY"))
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
client.run(messages)

Expand Down Expand Up @@ -278,7 +278,7 @@ def weather(city: str) -> str:

# Create and set up the pipeline
pipeline = Pipeline()
pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool]))
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
pipeline.connect("generator", "tool_invoker")

Expand All @@ -296,7 +296,7 @@ def weather(city: str) -> str:
def __init__(
self,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "command-r",
model: str = "command-r-08-2024",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
22 changes: 11 additions & 11 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_init_default(self, monkeypatch):

component = CohereChatGenerator()
assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"])
assert component.model == "command-r"
assert component.model == "command-r-08-2024"
assert component.streaming_callback is None
assert component.api_base_url == "https://api.cohere.com"
assert not component.generation_kwargs
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_to_dict_default(self, monkeypatch):
assert data == {
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model": "command-r",
"model": "command-r-08-2024",
"streaming_callback": None,
"api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"},
"api_base_url": "https://api.cohere.com",
Expand Down Expand Up @@ -116,15 +116,15 @@ def test_from_dict(self, monkeypatch):
data = {
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model": "command-r",
"model": "command-r-08-2024",
"api_base_url": "test-base-url",
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = CohereChatGenerator.from_dict(data)
assert component.model == "command-r"
assert component.model == "command-r-08-2024"
assert component.streaming_callback is print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -135,7 +135,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
data = {
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model": "command-r",
"model": "command-r-08-2024",
"api_base_url": "test-base-url",
"api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_tools_use_old_way(self):
},
}
]
client = CohereChatGenerator(model="command-r")
client = CohereChatGenerator(model="command-r-08-2024")
response = client.run(
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
generation_kwargs={"tools": tools_schema},
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_tools_use_with_tools(self):
function=stock_price,
)
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
client = CohereChatGenerator(model="command-r")
client = CohereChatGenerator(model="command-r-08-2024")
response = client.run(
messages=initial_messages,
tools=[stock_price_tool],
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_live_run_with_tools_streaming(self):

initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
component = CohereChatGenerator(
model="command-r", # Cohere's model that supports tools
model="command-r-08-2024", # Cohere's model that supports tools
tools=[weather_tool],
streaming_callback=print_streaming_chunk,
)
Expand Down Expand Up @@ -384,7 +384,7 @@ def test_pipeline_with_cohere_chat_generator(self):
)

pipeline = Pipeline()
pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool]))
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))

pipeline.connect("generator", "tool_invoker")
Expand Down Expand Up @@ -416,7 +416,7 @@ def test_serde_in_pipeline(self, monkeypatch):

# Create generator with specific configuration
generator = CohereChatGenerator(
model="command-r",
model="command-r-08-2024",
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
tools=[tool],
Expand All @@ -437,7 +437,7 @@ def test_serde_in_pipeline(self, monkeypatch):
"generator": {
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", # noqa: E501
"init_parameters": {
"model": "command-r",
"model": "command-r-08-2024",
"api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"api_base_url": "https://api.cohere.com",
Expand Down