@@ -1,10 +1,4 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    List,
-    TypedDict,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, List, TypedDict, Optional, Union
 
 from sentry_sdk.ai.utils import set_data_normalized
 from sentry_sdk.consts import SPANDATA
@@ -31,7 +25,21 @@ class AccumulatedResponse(TypedDict):
     text: str
     finish_reasons: "List[str]"
     tool_calls: "List[dict[str, Any]]"
-    usage_metadata: "UsageData"
+    usage_metadata: "Optional[UsageData]"
+
+
+def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
+    return UsageData(
+        input_tokens=max(self["input_tokens"], other["input_tokens"]),
+        output_tokens=max(self["output_tokens"], other["output_tokens"]),
+        input_tokens_cached=max(
+            self["input_tokens_cached"], other["input_tokens_cached"]
+        ),
+        output_tokens_reasoning=max(
+            self["output_tokens_reasoning"], other["output_tokens_reasoning"]
+        ),
+        total_tokens=max(self["total_tokens"], other["total_tokens"]),
+    )
 
 
 def accumulate_streaming_response(
@@ -41,11 +49,7 @@ def accumulate_streaming_response(
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_input_tokens = 0
-    total_output_tokens = 0
-    total_tokens = 0
-    total_cached_tokens = 0
-    total_reasoning_tokens = 0
+    usage_data = None
     response_id = None
     model = None
 
@@ -68,25 +72,21 @@ def accumulate_streaming_response(
         if extracted_tool_calls:
             tool_calls.extend(extracted_tool_calls)
 
-        # Accumulate token usage
-        extracted_usage_data = extract_usage_data(chunk)
-        total_input_tokens += extracted_usage_data["input_tokens"]
-        total_output_tokens += extracted_usage_data["output_tokens"]
-        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
-        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
-        total_tokens += extracted_usage_data["total_tokens"]
+        # Use the last possible chunk in case of interruption, and gracefully
+        # handle missing intermediate token counts by taking the element-wise
+        # maximum with the previously reported usage.
+        chunk_usage_data = extract_usage_data(chunk)
+        usage_data = (
+            chunk_usage_data
+            if usage_data is None
+            else element_wise_usage_max(usage_data, chunk_usage_data)
+        )
 
     accumulated_response = AccumulatedResponse(
         text="".join(accumulated_text),
         finish_reasons=finish_reasons,
         tool_calls=tool_calls,
-        usage_metadata=UsageData(
-            input_tokens=total_input_tokens,
-            output_tokens=total_output_tokens,
-            input_tokens_cached=total_cached_tokens,
-            output_tokens_reasoning=total_reasoning_tokens,
-            total_tokens=total_tokens,
-        ),
+        usage_metadata=usage_data,
         id=response_id,
         model=model,
     )
@@ -126,6 +126,9 @@ def set_span_data_for_streaming_response(
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
 
+    if accumulated_response["usage_metadata"] is None:
+        return
+
     if accumulated_response["usage_metadata"]["input_tokens"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
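
For context on the change: below is a minimal, self-contained sketch of the accumulation behavior introduced above. It assumes, as the new comment suggests, that providers report cumulative usage per chunk and that extract_usage_data fills missing counts with zeros; the UsageData shape is mirrored from the diff, and the chunk values are invented for illustration.

from typing import Optional, TypedDict


class UsageData(TypedDict):
    input_tokens: int
    output_tokens: int
    input_tokens_cached: int
    output_tokens_reasoning: int
    total_tokens: int


def element_wise_usage_max(self: "UsageData", other: "UsageData") -> "UsageData":
    # Field-wise maximum: cumulative counters only grow, so a chunk that
    # omits a field (extracted as 0) cannot drag the running totals down.
    return UsageData(
        input_tokens=max(self["input_tokens"], other["input_tokens"]),
        output_tokens=max(self["output_tokens"], other["output_tokens"]),
        input_tokens_cached=max(
            self["input_tokens_cached"], other["input_tokens_cached"]
        ),
        output_tokens_reasoning=max(
            self["output_tokens_reasoning"], other["output_tokens_reasoning"]
        ),
        total_tokens=max(self["total_tokens"], other["total_tokens"]),
    )


# Hypothetical per-chunk usage: the middle chunk reports no usage at all,
# and the last chunk carries the final cumulative totals.
chunks = [
    UsageData(input_tokens=10, output_tokens=3, input_tokens_cached=0,
              output_tokens_reasoning=0, total_tokens=13),
    UsageData(input_tokens=0, output_tokens=0, input_tokens_cached=0,
              output_tokens_reasoning=0, total_tokens=0),
    UsageData(input_tokens=10, output_tokens=9, input_tokens_cached=4,
              output_tokens_reasoning=2, total_tokens=19),
]

usage_data: Optional[UsageData] = None
for chunk_usage_data in chunks:
    # The same fold as in accumulate_streaming_response above.
    usage_data = (
        chunk_usage_data
        if usage_data is None
        else element_wise_usage_max(usage_data, chunk_usage_data)
    )

# The zero-filled chunk did not reset anything, and the result matches the
# last cumulative report. Summing instead, as the old code did, would have
# double-counted the cumulative reports (20 input / 12 output / 32 total).
assert usage_data == chunks[-1]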