-
Notifications
You must be signed in to change notification settings - Fork 558
feat(ai-monitoring): Cohere integration #3055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4c14156
Cohere integration
colin-sentry 98fc4e8
Fix lint
colin-sentry cb752c7
Fix bug with model ID not being pulled
colin-sentry 5e183cc
Exclude known models from langchain
colin-sentry 80ebc33
tox.ini
colin-sentry 6b8149a
Merge branch 'master' into cohere
antonpirker 12ba48b
Removed print statement
antonpirker 36ba134
Apply suggestions from code review
colin-sentry 2a69861
Merge branch 'master' into cohere
colin-sentry File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,7 @@ | |
"arq", | ||
"beam", | ||
"celery", | ||
"cohere", | ||
"huey", | ||
"langchain", | ||
"openai", | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
from functools import wraps | ||
|
||
from sentry_sdk import consts | ||
from sentry_sdk._types import TYPE_CHECKING | ||
from sentry_sdk.ai.monitoring import record_token_usage | ||
from sentry_sdk.consts import SPANDATA | ||
from sentry_sdk.ai.utils import set_data_normalized | ||
|
||
if TYPE_CHECKING: | ||
from typing import Any, Callable, Iterator | ||
from sentry_sdk.tracing import Span | ||
|
||
import sentry_sdk | ||
from sentry_sdk.scope import should_send_default_pii | ||
from sentry_sdk.integrations import DidNotEnable, Integration | ||
from sentry_sdk.utils import ( | ||
capture_internal_exceptions, | ||
event_from_exception, | ||
ensure_integration_enabled, | ||
) | ||
|
||
try: | ||
from cohere.client import Client | ||
from cohere.base_client import BaseCohere | ||
from cohere import ChatStreamEndEvent, NonStreamedChatResponse | ||
|
||
if TYPE_CHECKING: | ||
from cohere import StreamedChatResponse | ||
except ImportError: | ||
raise DidNotEnable("Cohere not installed") | ||
|
||
|
||
# Keyword arguments of ``chat``/``chat_stream`` that are always safe to
# record, mapped to the span-data key they are stored under.
COLLECTED_CHAT_PARAMS = {
    "model": SPANDATA.AI_MODEL_ID,
    "k": SPANDATA.AI_TOP_K,
    "p": SPANDATA.AI_TOP_P,
    "seed": SPANDATA.AI_SEED,
    "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
    "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
    "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
}

# Chat kwargs that may contain user data; only recorded when PII sending is
# enabled AND the integration was configured with ``include_prompts=True``.
COLLECTED_PII_CHAT_PARAMS = {
    "tools": SPANDATA.AI_TOOLS,
    "preamble": SPANDATA.AI_PREAMBLE,
}

# Attributes of the chat response copied verbatim onto the span under the
# given "ai.*" key (always recorded).
COLLECTED_CHAT_RESP_ATTRS = {
    "generation_id": "ai.generation_id",
    "is_search_required": "ai.is_search_required",
    "finish_reason": "ai.finish_reason",
}

# Response attributes that may contain user data; recorded only under the
# same PII conditions as COLLECTED_PII_CHAT_PARAMS.
COLLECTED_PII_CHAT_RESP_ATTRS = {
    "citations": "ai.citations",
    "documents": "ai.documents",
    "search_queries": "ai.search_queries",
    "search_results": "ai.search_results",
    "tool_calls": "ai.tool_calls",
}
|
||
|
||
class CohereIntegration(Integration):
    """Sentry integration instrumenting the Cohere client library."""

    identifier = "cohere"

    def __init__(self, include_prompts=True):
        # type: (CohereIntegration, bool) -> None
        # When True (and send_default_pii is enabled), prompts and responses
        # are attached to spans; set to False to omit potentially sensitive
        # message text.
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        # Monkeypatch the Cohere client entry points with instrumented
        # wrappers. ``chat_stream`` yields events, ``chat`` returns a single
        # response; ``streaming`` selects the handling strategy.
        BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
        BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
        Client.embed = _wrap_embed(Client.embed)
||
|
||
def _capture_exception(exc):
    # type: (Any) -> None
    """Report *exc* to Sentry, tagged as an unhandled Cohere error."""
    client = sentry_sdk.get_client()
    event, hint = event_from_exception(
        exc,
        client_options=client.options,
        mechanism={"type": "cohere", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)
|
||
|
||
def _wrap_chat(f, streaming):
    # type: (Callable[..., Any], bool) -> Callable[..., Any]
    """Wrap ``BaseCohere.chat`` / ``BaseCohere.chat_stream`` with a Sentry span.

    Records model parameters, token usage and (when PII is allowed and
    ``include_prompts`` is set) prompts/responses on the span.  ``streaming``
    selects between one-shot and iterator-based response handling.
    """

    def collect_chat_response_fields(span, res, include_pii):
        # type: (Span, NonStreamedChatResponse, bool) -> None
        # Copy the interesting attributes of the (final) chat response onto
        # the span; PII-bearing fields only when allowed.
        if include_pii:
            if hasattr(res, "text"):
                set_data_normalized(
                    span,
                    SPANDATA.AI_RESPONSES,
                    [res.text],
                )
            for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
                if hasattr(res, pii_attr):
                    set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))

        for attr in COLLECTED_CHAT_RESP_ATTRS:
            if hasattr(res, attr):
                set_data_normalized(span, "ai." + attr, getattr(res, attr))

        if hasattr(res, "meta"):
            # Prefer billed units; fall back to raw token counts.
            if hasattr(res.meta, "billed_units"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.billed_units.input_tokens,
                    completion_tokens=res.meta.billed_units.output_tokens,
                )
            elif hasattr(res.meta, "tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.tokens.input_tokens,
                    completion_tokens=res.meta.tokens.output_tokens,
                )

            if hasattr(res.meta, "warnings"):
                set_data_normalized(span, "ai.warnings", res.meta.warnings)

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_chat(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        # Only instrument calls with a plain-string ``message`` kwarg
        # (missing kwarg yields None, which is not a str); anything else is
        # passed through untouched.
        if not isinstance(kwargs.get("message"), str):
            return f(*args, **kwargs)

        message = kwargs["message"]

        span = sentry_sdk.start_span(
            op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
            description="cohere.client.Chat",
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(CohereIntegration)

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(
                    span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    list(
                        map(
                            lambda x: {
                                "role": getattr(x, "role", "").lower(),
                                "content": getattr(x, "message", ""),
                            },
                            kwargs.get("chat_history", []),
                        )
                    )
                    + [{"role": "user", "content": message}],
                )
                for k, v in COLLECTED_PII_CHAT_PARAMS.items():
                    if k in kwargs:
                        set_data_normalized(span, v, kwargs[k])

            for k, v in COLLECTED_CHAT_PARAMS.items():
                if k in kwargs:
                    set_data_normalized(span, v, kwargs[k])
            # BUGFIX: was hard-coded to False, mislabeling chat_stream spans;
            # record the actual mode of this wrapped entry point.
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            if streaming:
                old_iterator = res

                def new_iterator():
                    # type: () -> Iterator[StreamedChatResponse]
                    # BUGFIX: close the span in ``finally`` so it is not
                    # leaked when the consumer abandons the stream early
                    # (GeneratorExit) or iteration raises.
                    try:
                        with capture_internal_exceptions():
                            for x in old_iterator:
                                if isinstance(x, ChatStreamEndEvent):
                                    # The end event carries the aggregated
                                    # final response.
                                    collect_chat_response_fields(
                                        span,
                                        x.response,
                                        include_pii=should_send_default_pii()
                                        and integration.include_prompts,
                                    )
                                yield x
                    finally:
                        span.__exit__(None, None, None)

                return new_iterator()
            elif isinstance(res, NonStreamedChatResponse):
                collect_chat_response_fields(
                    span,
                    res,
                    include_pii=should_send_default_pii()
                    and integration.include_prompts,
                )
                span.__exit__(None, None, None)
            else:
                # Unknown return type — close the span but flag it so the
                # shape can be investigated.
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)
            return res

    return new_chat
|
||
|
||
def _wrap_embed(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``Client.embed`` so embedding calls produce a Sentry span."""

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_embed(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        with sentry_sdk.start_span(
            op=consts.OP.COHERE_EMBEDDINGS_CREATE,
            description="Cohere Embedding Creation",
        ) as span:
            integration = sentry_sdk.get_client().get_integration(CohereIntegration)

            # Record the embedded text(s) only when PII is allowed.
            if "texts" in kwargs and (
                should_send_default_pii() and integration.include_prompts
            ):
                texts = kwargs["texts"]
                if isinstance(texts, str):
                    set_data_normalized(span, "ai.texts", [texts])
                elif (
                    isinstance(texts, list)
                    and len(texts) > 0
                    and isinstance(texts[0], str)
                ):
                    set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, texts)

            if "model" in kwargs:
                set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])

            try:
                res = f(*args, **kwargs)
            except Exception as e:
                _capture_exception(e)
                raise e from None

            # Embeddings only consume input tokens, so prompt == total.
            billed = getattr(getattr(res, "meta", None), "billed_units", None)
            if billed is not None and hasattr(billed, "input_tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=billed.input_tokens,
                    total_tokens=billed.input_tokens,
                )
            return res

    return new_embed
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import pytest

# Skip this entire test package when the `cohere` SDK is not installed in
# the current tox environment.
pytest.importorskip("cohere")
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.