Improve token usage recording #4566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions sentry_sdk/ai/monitoring.py
@@ -96,21 +96,40 @@ async def async_wrapped(*args, **kwargs):


def record_token_usage(
-    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
+    span,
+    input_tokens=None,
+    input_tokens_cached=None,
+    output_tokens=None,
+    output_tokens_reasoning=None,
+    total_tokens=None,
):
-    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
+    # type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
+
+    # TODO: move pipeline name elsewhere
    ai_pipeline_name = get_ai_pipeline_name()
    if ai_pipeline_name:
        span.set_data(SPANDATA.AI_PIPELINE_NAME, ai_pipeline_name)
-    if prompt_tokens is not None:
-        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
-    if completion_tokens is not None:
-        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
-    if (
-        total_tokens is None
-        and prompt_tokens is not None
-        and completion_tokens is not None
-    ):
-        total_tokens = prompt_tokens + completion_tokens
+
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if input_tokens_cached is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
+            input_tokens_cached,
+        )
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if output_tokens_reasoning is not None:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
+            output_tokens_reasoning,
+        )
+
+    if total_tokens is None and input_tokens is not None and output_tokens is not None:
+        total_tokens = input_tokens + output_tokens

    if total_tokens is not None:
-        span.set_measurement("ai_total_tokens_used", total_tokens)
+        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
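
The reworked record_token_usage is keyword-driven and derives the total when it is omitted. A minimal sketch of calling it against a live span, assuming sentry_sdk.init() has been called; the op and name strings are illustrative, not prescribed by the SDK:

    import sentry_sdk
    from sentry_sdk.ai.monitoring import record_token_usage

    with sentry_sdk.start_transaction(op="ai-inference", name="chat"):
        with sentry_sdk.start_span(op="ai.chat_completions.create") as span:
            # total_tokens is omitted on purpose: with both input_tokens and
            # output_tokens present, the function records 100 + 50 = 150.
            record_token_usage(
                span,
                input_tokens=100,
                input_tokens_cached=20,
                output_tokens=50,
                output_tokens_reasoning=10,
            )
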
15 changes: 13 additions & 2 deletions sentry_sdk/integrations/anthropic.py
@@ -65,7 +65,13 @@ def _calculate_token_usage(result, span):
            output_tokens = usage.output_tokens

    total_tokens = input_tokens + output_tokens
-    record_token_usage(span, input_tokens, output_tokens, total_tokens)
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+    )


def _get_responses(content):
@@ -126,7 +132,12 @@ def _add_ai_data_to_span(
                [{"type": "text", "text": complete_message}],
            )
        total_tokens = input_tokens + output_tokens
-        record_token_usage(span, input_tokens, output_tokens, total_tokens)
+        record_token_usage(
+            span,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            total_tokens=total_tokens,
+        )
        span.set_data(SPANDATA.AI_STREAMING, True)

10 changes: 5 additions & 5 deletions sentry_sdk/integrations/cohere.py
@@ -116,14 +116,14 @@ def collect_chat_response_fields(span, res, include_pii):
        if hasattr(res.meta, "billed_units"):
            record_token_usage(
                span,
-                prompt_tokens=res.meta.billed_units.input_tokens,
-                completion_tokens=res.meta.billed_units.output_tokens,
+                input_tokens=res.meta.billed_units.input_tokens,
+                output_tokens=res.meta.billed_units.output_tokens,
            )
        elif hasattr(res.meta, "tokens"):
            record_token_usage(
                span,
-                prompt_tokens=res.meta.tokens.input_tokens,
-                completion_tokens=res.meta.tokens.output_tokens,
+                input_tokens=res.meta.tokens.input_tokens,
+                output_tokens=res.meta.tokens.output_tokens,
            )

        if hasattr(res.meta, "warnings"):
@@ -262,7 +262,7 @@ def new_embed(*args, **kwargs):
            ):
                record_token_usage(
                    span,
-                    prompt_tokens=res.meta.billed_units.input_tokens,
+                    input_tokens=res.meta.billed_units.input_tokens,
                    total_tokens=res.meta.billed_units.input_tokens,
                )
            return res
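
In the embeddings path there are no output tokens, so billed input tokens double as the total. A tiny sketch of that call, with a SimpleNamespace standing in for Cohere's billed_units metadata and an illustrative op string:

    import sentry_sdk
    from types import SimpleNamespace
    from sentry_sdk.ai.monitoring import record_token_usage

    billed_units = SimpleNamespace(input_tokens=128)  # stand-in for res.meta.billed_units

    with sentry_sdk.start_span(op="ai.embeddings.create.cohere") as span:
        # Mirrors the embed path above: an embed call produces no output
        # tokens, so the billed input count is recorded as the total too.
        record_token_usage(
            span,
            input_tokens=billed_units.input_tokens,
            total_tokens=billed_units.input_tokens,
        )
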
10 changes: 8 additions & 2 deletions sentry_sdk/integrations/huggingface_hub.py
@@ -111,7 +111,10 @@ def new_text_generation(*args, **kwargs):
                    [res.generated_text],
                )
            if res.details is not None and res.details.generated_tokens > 0:
-                record_token_usage(span, total_tokens=res.details.generated_tokens)
+                record_token_usage(
+                    span,
+                    total_tokens=res.details.generated_tokens,
+                )
            span.__exit__(None, None, None)
            return res

@@ -145,7 +148,10 @@ def new_details_iterator():
                            span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                        )
                    if tokens_used > 0:
-                        record_token_usage(span, total_tokens=tokens_used)
+                        record_token_usage(
+                            span,
+                            total_tokens=tokens_used,
+                        )
                    span.__exit__(None, None, None)

            return new_details_iterator()
10 changes: 5 additions & 5 deletions sentry_sdk/integrations/langchain.py
@@ -278,15 +278,15 @@ def on_llm_end(self, response, *, run_id, **kwargs):
            if token_usage:
                record_token_usage(
                    span_data.span,
-                    token_usage.get("prompt_tokens"),
-                    token_usage.get("completion_tokens"),
-                    token_usage.get("total_tokens"),
+                    input_tokens=token_usage.get("prompt_tokens"),
+                    output_tokens=token_usage.get("completion_tokens"),
+                    total_tokens=token_usage.get("total_tokens"),
                )
            else:
                record_token_usage(
                    span_data.span,
-                    span_data.num_prompt_tokens,
-                    span_data.num_completion_tokens,
+                    input_tokens=span_data.num_prompt_tokens,
+                    output_tokens=span_data.num_completion_tokens,
                )

            self._exit_span(span_data, run_id)
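
LangChain hands the callback a token_usage dict that still uses the legacy prompt/completion key names, so the mapping to the new keyword arguments happens at the call site. A rough sketch with a hypothetical payload shaped like LangChain's llm_output["token_usage"] (the op string is illustrative):

    import sentry_sdk
    from sentry_sdk.ai.monitoring import record_token_usage

    # Hypothetical payload; keys follow LangChain's llm_output["token_usage"].
    token_usage = {"prompt_tokens": 21, "completion_tokens": 5, "total_tokens": 26}

    with sentry_sdk.start_span(op="ai.run.langchain") as span:
        record_token_usage(
            span,
            input_tokens=token_usage.get("prompt_tokens"),
            output_tokens=token_usage.get("completion_tokens"),
            total_tokens=token_usage.get("total_tokens"),
        )
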
105 changes: 67 additions & 38 deletions sentry_sdk/integrations/openai.py
@@ -70,48 +70,75 @@ def _capture_exception(exc):
    sentry_sdk.capture_event(event, hint=hint)


-def _calculate_chat_completion_usage(
+def _get_usage(usage, names):
+    # type: (Any, List[str]) -> int
+    for name in names:
+        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
+            return getattr(usage, name)
+    return 0
+
+
+def _calculate_token_usage(
    messages, response, span, streaming_message_responses, count_tokens
):
    # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
-    completion_tokens = 0  # type: Optional[int]
-    prompt_tokens = 0  # type: Optional[int]
+    input_tokens = 0  # type: Optional[int]
+    input_tokens_cached = 0  # type: Optional[int]
+    output_tokens = 0  # type: Optional[int]
+    output_tokens_reasoning = 0  # type: Optional[int]
    total_tokens = 0  # type: Optional[int]

    if hasattr(response, "usage"):
-        if hasattr(response.usage, "completion_tokens") and isinstance(
-            response.usage.completion_tokens, int
-        ):
-            completion_tokens = response.usage.completion_tokens
-        if hasattr(response.usage, "prompt_tokens") and isinstance(
-            response.usage.prompt_tokens, int
-        ):
-            prompt_tokens = response.usage.prompt_tokens
-        if hasattr(response.usage, "total_tokens") and isinstance(
-            response.usage.total_tokens, int
-        ):
-            total_tokens = response.usage.total_tokens
+        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
+        if hasattr(response.usage, "input_tokens_details"):
+            input_tokens_cached = _get_usage(
+                response.usage.input_tokens_details, ["cached_tokens"]
+            )
+
+        output_tokens = _get_usage(
+            response.usage, ["output_tokens", "completion_tokens"]
+        )
+        if hasattr(response.usage, "output_tokens_details"):
+            output_tokens_reasoning = _get_usage(
+                response.usage.output_tokens_details, ["reasoning_tokens"]
+            )

-    if prompt_tokens == 0:
+        total_tokens = _get_usage(response.usage, ["total_tokens"])
+
+    # Manually count tokens
+    # TODO: check for responses API
+    if input_tokens == 0:
        for message in messages:
            if "content" in message:
-                prompt_tokens += count_tokens(message["content"])
+                input_tokens += count_tokens(message["content"])

-    if completion_tokens == 0:
+    # TODO: check for responses API
+    if output_tokens == 0:
        if streaming_message_responses is not None:
            for message in streaming_message_responses:
-                completion_tokens += count_tokens(message)
+                output_tokens += count_tokens(message)
        elif hasattr(response, "choices"):
            for choice in response.choices:
                if hasattr(choice, "message"):
-                    completion_tokens += count_tokens(choice.message)
-
-    if prompt_tokens == 0:
-        prompt_tokens = None
-    if completion_tokens == 0:
-        completion_tokens = None
-    if total_tokens == 0:
-        total_tokens = None
-    record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
+                    output_tokens += count_tokens(choice.message)
+
+    # Do not set token data if it is 0
+    input_tokens = None if input_tokens == 0 else input_tokens
+    input_tokens_cached = None if input_tokens_cached == 0 else input_tokens_cached
+    output_tokens = None if output_tokens == 0 else output_tokens
+    output_tokens_reasoning = (
+        None if output_tokens_reasoning == 0 else output_tokens_reasoning
+    )
+    total_tokens = None if total_tokens == 0 else total_tokens
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        input_tokens_cached=input_tokens_cached,
+        output_tokens=output_tokens,
+        output_tokens_reasoning=output_tokens_reasoning,
+        total_tokens=total_tokens,
+    )


def _new_chat_completion_common(f, *args, **kwargs):
@@ -158,9 +185,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
                    SPANDATA.AI_RESPONSES,
                    list(map(lambda x: x.message, res.choices)),
                )
-            _calculate_chat_completion_usage(
-                messages, res, span, None, integration.count_tokens
-            )
+            _calculate_token_usage(messages, res, span, None, integration.count_tokens)
            span.__exit__(None, None, None)
        elif hasattr(res, "_iterator"):
            data_buf: list[list[str]] = []  # one for each choice
@@ -191,7 +216,7 @@ def new_iterator():
                        set_data_normalized(
                            span, SPANDATA.AI_RESPONSES, all_responses
                        )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                        messages,
                        res,
                        span,
@@ -224,7 +249,7 @@ async def new_iterator_async():
                        set_data_normalized(
                            span, SPANDATA.AI_RESPONSES, all_responses
                        )
-                    _calculate_chat_completion_usage(
+                    _calculate_token_usage(
                        messages,
                        res,
                        span,
@@ -341,22 +366,26 @@ def _new_embeddings_create_common(f, *args, **kwargs):

    response = yield f, args, kwargs

-    prompt_tokens = 0
+    input_tokens = 0
    total_tokens = 0
    if hasattr(response, "usage"):
        if hasattr(response.usage, "prompt_tokens") and isinstance(
            response.usage.prompt_tokens, int
        ):
-            prompt_tokens = response.usage.prompt_tokens
+            input_tokens = response.usage.prompt_tokens
        if hasattr(response.usage, "total_tokens") and isinstance(
            response.usage.total_tokens, int
        ):
            total_tokens = response.usage.total_tokens

-    if prompt_tokens == 0:
-        prompt_tokens = integration.count_tokens(kwargs["input"] or "")
+    if input_tokens == 0:
+        input_tokens = integration.count_tokens(kwargs["input"] or "")

-    record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        total_tokens=total_tokens or input_tokens,
+    )

    return response

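
The new _get_usage helper is what lets one code path read both usage shapes: Chat Completions responses report prompt_tokens/completion_tokens, while the newer Responses API reports input_tokens/output_tokens, which is exactly what the fallback lists encode. A standalone sketch of that behavior; the SimpleNamespace stubs stand in for real OpenAI usage objects:

    from types import SimpleNamespace


    def _get_usage(usage, names):
        # Return the first attribute in `names` that exists and is an int, else 0.
        for name in names:
            if hasattr(usage, name) and isinstance(getattr(usage, name), int):
                return getattr(usage, name)
        return 0


    chat_usage = SimpleNamespace(prompt_tokens=42, completion_tokens=7)  # Chat Completions
    responses_usage = SimpleNamespace(input_tokens=42, output_tokens=7)  # Responses API

    for usage in (chat_usage, responses_usage):
        print(
            _get_usage(usage, ["input_tokens", "prompt_tokens"]),
            _get_usage(usage, ["output_tokens", "completion_tokens"]),
        )
    # Both iterations print: 42 7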