Commit 11d68ee

fix(google-genai): Token reporting (#5404)
Stop double accumulating tokens, as Gemini returns cumulative token counts in its streaming chunks. Take an element-wise maximum to correctly count tokens when receiving a streaming response; the maximum accounts for the fact that some token types may be missing from intermediate chunks.
1 parent 78c6011 commit 11d68ee
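
To make the failure mode concrete: a minimal sketch with hypothetical cumulative usage reports (the dict keys mirror the integration's UsageData fields; none of these numbers come from the commit itself).

```python
# Gemini streaming chunks report *cumulative* usage, so each chunk's counts
# already include all earlier chunks. Hypothetical reports for one stream;
# the middle chunk omits some fields, which extraction surfaces as 0.
reports = [
    {"input_tokens": 9, "output_tokens": 4, "total_tokens": 13},
    {"input_tokens": 0, "output_tokens": 6, "total_tokens": 0},  # fields missing
    {"input_tokens": 9, "output_tokens": 8, "total_tokens": 17},
]

# Old behavior: summing cumulative values double counts.
assert sum(r["output_tokens"] for r in reports) == 18  # inflated

# Fixed behavior: an element-wise maximum recovers the final cumulative
# counts, tolerates the zeroed fields, and still yields a sensible answer
# if the stream is interrupted before the last chunk.
merged = {key: max(r[key] for r in reports) for key in reports[0]}
assert merged == {"input_tokens": 9, "output_tokens": 8, "total_tokens": 17}
```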

2 files changed: +38 / -42 lines changed


sentry_sdk/integrations/google_genai/streaming.py

Lines changed: 30 additions & 27 deletions

```diff
@@ -1,10 +1,4 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    List,
-    TypedDict,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, List, TypedDict, Optional, Union
 
 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import SPANDATA
@@ -31,7 +25,21 @@ class AccumulatedResponse(TypedDict):
     text: str
     finish_reasons: "List[str]"
     tool_calls: "List[dict[str, Any]]"
-    usage_metadata: "UsageData"
+    usage_metadata: "Optional[UsageData]"
+
+
+def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
+    return UsageData(
+        input_tokens=max(self["input_tokens"], other["input_tokens"]),
+        output_tokens=max(self["output_tokens"], other["output_tokens"]),
+        input_tokens_cached=max(
+            self["input_tokens_cached"], other["input_tokens_cached"]
+        ),
+        output_tokens_reasoning=max(
+            self["output_tokens_reasoning"], other["output_tokens_reasoning"]
+        ),
+        total_tokens=max(self["total_tokens"], other["total_tokens"]),
+    )
 
 
 def accumulate_streaming_response(
@@ -41,11 +49,7 @@ def accumulate_streaming_response(
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_input_tokens = 0
-    total_output_tokens = 0
-    total_tokens = 0
-    total_cached_tokens = 0
-    total_reasoning_tokens = 0
+    usage_data = None
     response_id = None
     model = None
 
@@ -68,25 +72,21 @@ def accumulate_streaming_response(
         if extracted_tool_calls:
             tool_calls.extend(extracted_tool_calls)
 
-        # Accumulate token usage
-        extracted_usage_data = extract_usage_data(chunk)
-        total_input_tokens += extracted_usage_data["input_tokens"]
-        total_output_tokens += extracted_usage_data["output_tokens"]
-        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
-        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
-        total_tokens += extracted_usage_data["total_tokens"]
+        # Use last possible chunk, in case of interruption, and
+        # gracefully handle missing intermediate tokens by taking maximum
+        # with previous token reporting.
+        chunk_usage_data = extract_usage_data(chunk)
+        usage_data = (
+            chunk_usage_data
+            if usage_data is None
+            else element_wise_usage_max(usage_data, chunk_usage_data)
+        )
 
     accumulated_response = AccumulatedResponse(
         text="".join(accumulated_text),
         finish_reasons=finish_reasons,
         tool_calls=tool_calls,
-        usage_metadata=UsageData(
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            input_tokens_cached=total_cached_tokens,
-            output_tokens_reasoning=total_reasoning_tokens,
-            total_tokens=total_tokens,
-        ),
+        usage_metadata=usage_data,
         id=response_id,
         model=model,
     )
@@ -126,6 +126,9 @@ def set_span_data_for_streaming_response(
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
 
+    if accumulated_response["usage_metadata"] is None:
+        return
+
     if accumulated_response["usage_metadata"]["input_tokens"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
```
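
For reference, the new accumulation loop is just a fold over per-chunk usage reports. A standalone sketch of the equivalent logic, assuming `reports` is a hypothetical list of already-extracted usage dicts (keys mirror the UsageData TypedDict):

```python
from functools import reduce

# Hypothetical usage dicts extracted from two streaming chunks.
reports = [
    {"input_tokens": 10, "output_tokens": 2, "input_tokens_cached": 0,
     "output_tokens_reasoning": 0, "total_tokens": 12},
    {"input_tokens": 10, "output_tokens": 10, "input_tokens_cached": 5,
     "output_tokens_reasoning": 3, "total_tokens": 25},
]

def element_wise_usage_max(a, b):
    # Field-by-field maximum, matching the diff's element_wise_usage_max.
    return {key: max(a[key], b[key]) for key in a}

# The chunk loop reduces to this fold. An empty stream leaves the result as
# None, which is why set_span_data_for_streaming_response now returns early
# when usage_metadata is None.
usage_data = reduce(element_wise_usage_max, reports) if reports else None
assert usage_data["total_tokens"] == 25
```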

tests/integrations/google_genai/test_google_genai.py

Lines changed: 8 additions & 15 deletions

```diff
@@ -452,13 +452,13 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         "usageMetadata": {
             "promptTokenCount": 10,
             "candidatesTokenCount": 2,
-            "totalTokenCount": 12,  # Not set in intermediate chunks
+            "totalTokenCount": 12,
         },
         "responseId": "response-id-stream-123",
         "modelVersion": "gemini-1.5-flash",
     }
 
-    # Chunk 2: Second part of text with more usage metadata
+    # Chunk 2: Second part of text with intermediate usage metadata
     chunk2_json = {
         "candidates": [
             {
@@ -545,25 +545,18 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
     assert chat_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
     assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == "STOP"
 
-    # Verify token counts - should reflect accumulated values
-    # Input tokens: max of all chunks = 10
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
 
-    # Output tokens: candidates (2 + 3 + 7 = 12) + reasoning (3) = 15
-    # Note: output_tokens includes both candidates and reasoning tokens
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 10
 
-    # Total tokens: from the last chunk
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 25
 
-    # Cached tokens: max of all chunks = 5
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
 
-    # Reasoning tokens: sum of thoughts_token_count = 3
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING] == 3
 
```
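
To reconcile the old and new expected values: a sketch with per-chunk cumulative reports inferred from the removed comments (chunk 2's exact figures are an assumption; only chunk 1 and the final totals are visible in the diff).

```python
# Cumulative reports: chunk 3's output of 10 is 7 candidate tokens plus
# 3 reasoning tokens; chunk 2's total of 13 is inferred, not from the diff.
reports = [
    {"input": 10, "output": 2, "total": 12},
    {"input": 10, "output": 3, "total": 13},
    {"input": 10, "output": 10, "total": 25},
]

# Old (buggy) accumulation summed the cumulative values:
assert sum(r["input"] for r in reports) == 30   # old expected value
assert sum(r["output"] for r in reports) == 15  # old expected value
assert sum(r["total"] for r in reports) == 50   # old expected value

# New accumulation takes the element-wise maximum:
assert max(r["input"] for r in reports) == 10   # new expected value
assert max(r["output"] for r in reports) == 10  # new expected value
assert max(r["total"] for r in reports) == 25   # new expected value
```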
