Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@
"cohere": {
"deps": [
"apache-airflow>=2.9.0",
"cohere>=4.37,<5"
"cohere>=5.13.4"
],
"devel-deps": [],
"plugins": [],
Expand Down
70 changes: 55 additions & 15 deletions providers/src/airflow/providers/cohere/hooks/cohere.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand All @@ -15,25 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
import warnings
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

import cohere
from cohere.types import UserChatMessageV2

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings


logger = logging.getLogger(__name__)


class CohereHook(BaseHook):
"""
Use Cohere Python SDK to interact with Cohere platform.
Use Cohere Python SDK to interact with Cohere platform using API v2.

.. seealso:: https://docs.cohere.com/docs

:param conn_id: :ref:`Cohere connection id <howto/connection:cohere>`
:param timeout: Request timeout in seconds.
:param max_retries: Maximal number of retries for requests.
:param timeout: Request timeout in seconds. Optional.
:param max_retries: Maximal number of retries for requests. Deprecated, use request_options instead. Optional.
:param request_options: Dictionary for function-specific request configuration. Optional.
"""

conn_name_attr = "conn_id"
Expand All @@ -46,23 +58,45 @@ def __init__(
conn_id: str = default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: RequestOptions | None = None,
) -> None:
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options

if self.max_retries:
warnings.warn(
"Argument `max_retries` is deprecated. Use `request_options` dict for function-specific request configuration.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if self.request_options is None:
self.request_options = {"max_retries": self.max_retries}
else:
self.request_options.update({"max_retries": self.max_retries})

@cached_property
def get_conn(self) -> cohere.Client: # type: ignore[override]
def get_conn(self) -> cohere.ClientV2: # type: ignore[override]
conn = self.get_connection(self.conn_id)
return cohere.Client(
api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host
return cohere.ClientV2(
api_key=conn.password,
timeout=self.timeout,
base_url=conn.host or None,
)

def create_embeddings(
self, texts: list[str], model: str = "embed-multilingual-v2.0"
) -> list[list[float]]:
response = self.get_conn.embed(texts=texts, model=model)
self, texts: list[str], model: str = "embed-multilingual-v3.0"
) -> EmbedByTypeResponseEmbeddings:
logger.info("Creating embeddings with model: embed-multilingual-v3.0")
response = self.get_conn.embed(
texts=texts,
model=model,
input_type="search_document",
embedding_types=["float"],
request_options=self.request_options,
)
embeddings = response.embeddings
return embeddings

Expand All @@ -75,9 +109,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def test_connection(self) -> tuple[bool, str]:
def test_connection(
self,
model: str = "command-r-plus-08-2024",
messages: ChatMessages | None = None,
) -> tuple[bool, str]:
try:
self.get_conn.generate("Test", max_tokens=10)
return True, "Connection established"
if messages is None:
messages = [UserChatMessageV2(role="user", content="hello world!")]
self.get_conn.chat(model=model, messages=messages)
return True, "Connection successfully established."
except Exception as e:
return False, str(e)
return False, f"Unexpected error: {str(e)}"
25 changes: 23 additions & 2 deletions providers/src/airflow/providers/cohere/operators/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from airflow.providers.cohere.hooks.cohere import CohereHook

if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
from cohere.types import EmbedByTypeResponseEmbeddings

from airflow.utils.context import Context


Expand All @@ -41,6 +44,17 @@ class CohereEmbeddingOperator(BaseOperator):
information for Cohere. Defaults to "cohere_default".
:param timeout: Timeout in seconds for Cohere API.
:param max_retries: Number of times to retry before failing.
:param request_options: Request-specific configuration.
Fields:
- timeout_in_seconds: int. The number of seconds to await an API call before timing out.

- max_retries: int. The max number of retries to attempt if the API call fails.

- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict

- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict

- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
"""

template_fields: Sequence[str] = ("input_text",)
Expand All @@ -51,6 +65,7 @@ def __init__(
conn_id: str = CohereHook.default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: RequestOptions | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -60,12 +75,18 @@ def __init__(
self.input_text = input_text
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options

@cached_property
def hook(self) -> CohereHook:
"""Return an instance of the CohereHook."""
return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries)
return CohereHook(
conn_id=self.conn_id,
timeout=self.timeout,
max_retries=self.max_retries,
request_options=self.request_options,
)

def execute(self, context: Context) -> list[list[float]]:
def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
"""Embed texts using Cohere embed services."""
return self.hook.create_embeddings(self.input_text)
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/cohere/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ integrations:

dependencies:
- apache-airflow>=2.9.0
- cohere>=4.37,<5
- cohere>=5.13.4

hooks:
- integration-name: Cohere
Expand Down
13 changes: 5 additions & 8 deletions providers/tests/cohere/hooks/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,16 @@ class TestCohereHook:

def test__get_api_key(self):
api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
with (
patch.object(
CohereHook,
"get_connection",
return_value=Connection(conn_type="cohere", password=api_key, host=api_url),
return_value=Connection(conn_type="cohere", password=api_key, host=base_url),
),
patch("cohere.Client") as client,
patch("cohere.ClientV2") as client,
):
hook = CohereHook(timeout=timeout, max_retries=max_retries)
hook = CohereHook(timeout=timeout)
_ = hook.get_conn
client.assert_called_once_with(
api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url
)
client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)
18 changes: 10 additions & 8 deletions providers/tests/cohere/operators/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@patch("airflow.providers.cohere.hooks.cohere.CohereHook.get_connection")
@patch("cohere.Client")
@patch("cohere.ClientV2")
def test_cohere_embedding_operator(cohere_client, get_connection):
"""
Test Cohere client is getting called with the correct key and that
Expand All @@ -35,22 +35,24 @@ class resp:
embeddings = embedded_obj

api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"]
request_options = None

get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url)
get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url)
client_obj = MagicMock()
cohere_client.return_value = client_obj
client_obj.embed.return_value = resp

op = CohereEmbeddingOperator(
task_id="embed", conn_id="some_conn", input_text=texts, timeout=timeout, max_retries=max_retries
task_id="embed",
conn_id="some_conn",
input_text=texts,
timeout=timeout,
request_options=request_options,
)

val = op.execute(context={})
cohere_client.assert_called_once_with(
api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries
)
cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout)
assert val == embedded_obj