Fix/gemini prompt caching usage feedback #11095

Open
wants to merge 7 commits into base: main
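
For context: Gemini reports prompt-cache usage in its usageMetadata block (cachedContentTokenCount, per-modality promptTokensDetails, and thoughtsTokenCount for reasoning), and this PR maps those fields onto the OpenAI-style Usage object that litellm returns. A minimal sketch of that mapping, using plain dicts for illustration; map_gemini_usage is a hypothetical helper, not a litellm API, and the field names follow the streaming_handler.py diff below.

from typing import Any, Dict

def map_gemini_usage(usage_metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Map a Gemini usageMetadata payload to OpenAI-style usage keys."""
    # Per-modality prompt token counts, e.g. [{"modality": "TEXT", "tokenCount": 1200}]
    modality_tokens = {
        detail.get("modality"): detail.get("tokenCount", 0)
        for detail in usage_metadata.get("promptTokensDetails", [])
    }
    return {
        "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
        "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
        "total_tokens": usage_metadata.get("totalTokenCount", 0),
        "prompt_tokens_details": {
            "cached_tokens": usage_metadata.get("cachedContentTokenCount", 0),
            "text_tokens": modality_tokens.get("TEXT", 0),
            "audio_tokens": modality_tokens.get("AUDIO", 0),
            "image_tokens": modality_tokens.get("IMAGE", 0),
        },
        "reasoning_tokens": usage_metadata.get("thoughtsTokenCount"),
    }

# Example payload shaped like the streaming chunks handled below.
print(map_gemini_usage({
    "promptTokenCount": 1200,
    "candidatesTokenCount": 80,
    "totalTokenCount": 1280,
    "cachedContentTokenCount": 1024,
    "promptTokensDetails": [{"modality": "TEXT", "tokenCount": 1200}],
    "thoughtsTokenCount": 15,
}))
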
4 changes: 2 additions & 2 deletions litellm/cost_calculator.py
@@ -1237,7 +1237,7 @@ def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
"""
from litellm.types.utils import (
CompletionTokensDetails,
CompletionTokensDetailsWrapper,
PromptTokensDetailsWrapper,
Usage,
)
@@ -1288,7 +1288,7 @@ def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
not hasattr(combined, "completion_tokens_details")
or not combined.completion_tokens_details
):
combined.completion_tokens_details = CompletionTokensDetails()
combined.completion_tokens_details = CompletionTokensDetailsWrapper()

# Check what keys exist in the model's completion_tokens_details
for attr in dir(usage.completion_tokens_details):
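
The type swap above keeps the combined details object as CompletionTokensDetailsWrapper, which appears to be litellm's subclass of the OpenAI CompletionTokensDetails type (the isinstance check added in streaming_chunk_builder_utils.py below suggests as much), so wrapper-only fields are preserved when usage objects are merged. A minimal sketch of the merge pattern, assuming these types are pydantic v2 models with model_dump(); merge_completion_details is an illustrative helper, not part of the PR.

from litellm.types.utils import CompletionTokensDetailsWrapper

def merge_completion_details(
    combined: CompletionTokensDetailsWrapper,
    incoming: CompletionTokensDetailsWrapper,
) -> CompletionTokensDetailsWrapper:
    # Copy any populated field from the incoming chunk that the combined
    # object does not have yet, mirroring the attribute loop above.
    for field, value in incoming.model_dump().items():
        if value is not None and getattr(combined, field, None) is None:
            setattr(combined, field, value)
    return combined
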
2 changes: 2 additions & 0 deletions litellm/integrations/custom_logger.py
@@ -340,12 +340,14 @@ async def async_log_event(
):
# Method definition
try:
cache_hit = kwargs.get("cache_hit", False)
kwargs["log_event_type"] = "post_api_call"
await callback_func(
kwargs, # kwargs to func
response_obj,
start_time,
end_time,
cache_hit
)
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
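
With the change above, callback functions invoked through async_log_event receive cache_hit as an additional positional argument, so callbacks written for the previous four-argument signature would need to accept it (via a default value or *args). An illustrative callback shape; the function name, the kwargs key, and the logging are examples, not a litellm API.

from datetime import datetime
from typing import Any, Dict

async def my_custom_callback(
    kwargs: Dict[str, Any],
    response_obj: Any,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool = False,
) -> None:
    # Log whether this completion was served from cache.
    if cache_hit:
        print(f"cache hit for call id={kwargs.get('litellm_call_id')}")
    else:
        print(f"cache miss; latency={(end_time - start_time).total_seconds():.2f}s")
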
@@ -129,6 +129,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
),
)

if "prompt_tokens_details" in response_object["usage"] and response_object["usage"]["prompt_tokens_details"] is not None and model_response_object.usage is not None:
model_response_object.usage.prompt_tokens_details = response_object["usage"]["prompt_tokens_details"]

if "id" in response_object:
model_response_object.id = response_object["id"]

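
The guard added above forwards prompt_tokens_details from the provider's raw usage dict onto the converted streaming response instead of silently dropping it. A small sketch of the same guard on plain dicts; the shapes and helper name are illustrative.

from typing import Any, Dict, Optional

def forward_prompt_tokens_details(
    raw_usage: Dict[str, Any], usage_obj: Dict[str, Any]
) -> Dict[str, Any]:
    # Only copy the field when the provider actually reported it.
    details: Optional[Dict[str, Any]] = raw_usage.get("prompt_tokens_details")
    if details is not None:
        usage_obj["prompt_tokens_details"] = details
    return usage_obj

usage = forward_prompt_tokens_details(
    {"prompt_tokens": 1100, "prompt_tokens_details": {"cached_tokens": 1024}},
    {"prompt_tokens": 1100, "completion_tokens": 42},
)
print(usage["prompt_tokens_details"]["cached_tokens"])  # 1024
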
86 changes: 53 additions & 33 deletions litellm/litellm_core_utils/streaming_chunk_builder_utils.py
@@ -2,6 +2,7 @@
import time
from typing import Any, Dict, List, Optional, Union, cast

from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.types.llms.openai import (
ChatCompletionAssistantContentValue,
ChatCompletionAudioDelta,
@@ -255,8 +256,8 @@ def _usage_chunk_calculation_helper(self, usage_chunk: Usage) -> dict:
## anthropic prompt caching information ##
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None

if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
@@ -277,7 +278,7 @@ def _usage_chunk_calculation_helper(self, usage_chunk: Usage) -> dict:
completion_tokens_details = usage_chunk.completion_tokens_details
if hasattr(usage_chunk, "prompt_tokens_details"):
if isinstance(usage_chunk.prompt_tokens_details, dict):
prompt_tokens_details = PromptTokensDetails(
prompt_tokens_details = PromptTokensDetailsWrapper(
**usage_chunk.prompt_tokens_details
)
elif isinstance(usage_chunk.prompt_tokens_details, PromptTokensDetails):
@@ -306,6 +307,45 @@ def count_reasoning_tokens(self, response: ModelResponse) -> int:

return reasoning_tokens

def _set_token_details(
self,
returned_usage: Usage,
cache_creation_input_tokens: Optional[int],
cache_read_input_tokens: Optional[int],
completion_tokens_details: Optional[CompletionTokensDetails],
prompt_tokens_details: Optional[PromptTokensDetailsWrapper],
reasoning_tokens: Optional[int],
) -> None:
"""
Helper method to set token details on the usage object
"""
if cache_creation_input_tokens is not None:
returned_usage._cache_creation_input_tokens = cache_creation_input_tokens
returned_usage.cache_creation_input_tokens = cache_creation_input_tokens
if cache_read_input_tokens is not None:
returned_usage._cache_read_input_tokens = cache_read_input_tokens
returned_usage.cache_read_input_tokens = cache_read_input_tokens

if completion_tokens_details is not None:
if isinstance(completion_tokens_details, CompletionTokensDetails) and not isinstance(completion_tokens_details, CompletionTokensDetailsWrapper):
returned_usage.completion_tokens_details = CompletionTokensDetailsWrapper(**completion_tokens_details.model_dump())
else:
returned_usage.completion_tokens_details = completion_tokens_details

if reasoning_tokens is not None:
if returned_usage.completion_tokens_details is None:
returned_usage.completion_tokens_details = (
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
)
elif (
returned_usage.completion_tokens_details is not None
and returned_usage.completion_tokens_details.reasoning_tokens is None
):
returned_usage.completion_tokens_details.reasoning_tokens = reasoning_tokens

if prompt_tokens_details is not None:
returned_usage.prompt_tokens_details = prompt_tokens_details

def calculate_usage(
self,
chunks: List[Union[Dict[str, Any], ModelResponse]],
@@ -325,7 +365,7 @@ def calculate_usage(
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
completion_tokens_details: Optional[CompletionTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetails] = None
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
@@ -385,35 +425,15 @@ def calculate_usage(
returned_usage.prompt_tokens + returned_usage.completion_tokens
)

if cache_creation_input_tokens is not None:
returned_usage._cache_creation_input_tokens = cache_creation_input_tokens
setattr(
returned_usage,
"cache_creation_input_tokens",
cache_creation_input_tokens,
) # for anthropic
if cache_read_input_tokens is not None:
returned_usage._cache_read_input_tokens = cache_read_input_tokens
setattr(
returned_usage, "cache_read_input_tokens", cache_read_input_tokens
) # for anthropic
if completion_tokens_details is not None:
returned_usage.completion_tokens_details = completion_tokens_details

if reasoning_tokens is not None:
if returned_usage.completion_tokens_details is None:
returned_usage.completion_tokens_details = (
CompletionTokensDetailsWrapper(reasoning_tokens=reasoning_tokens)
)
elif (
returned_usage.completion_tokens_details is not None
and returned_usage.completion_tokens_details.reasoning_tokens is None
):
returned_usage.completion_tokens_details.reasoning_tokens = (
reasoning_tokens
)
if prompt_tokens_details is not None:
returned_usage.prompt_tokens_details = prompt_tokens_details
# Set token details using the helper method
self._set_token_details(
returned_usage,
cache_creation_input_tokens,
cache_read_input_tokens,
completion_tokens_details,
prompt_tokens_details,
reasoning_tokens,
)

return returned_usage

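
The new _set_token_details helper centralizes the bookkeeping that previously sat inline in calculate_usage: Anthropic-style cache counters become attributes on the Usage object, while OpenAI-style detail objects and reasoning tokens are attached as wrapper instances. A minimal sketch of the resulting Usage shape, with illustrative values; setting the cache counters by plain attribute assignment is assumed to work as it does in the diff above.

from litellm.types.utils import (
    CompletionTokensDetailsWrapper,
    PromptTokensDetailsWrapper,
    Usage,
)

usage = Usage(prompt_tokens=1200, completion_tokens=80, total_tokens=1280)

# Anthropic-style cache counters end up as plain attributes on the Usage object...
usage.cache_creation_input_tokens = 0
usage.cache_read_input_tokens = 1024

# ...while OpenAI-style detail objects are attached as wrapper instances.
usage.prompt_tokens_details = PromptTokensDetailsWrapper(cached_tokens=1024, text_tokens=1200)
usage.completion_tokens_details = CompletionTokensDetailsWrapper(reasoning_tokens=15)

print(usage.prompt_tokens_details.cached_tokens, usage.completion_tokens_details.reasoning_tokens)
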
127 changes: 121 additions & 6 deletions litellm/litellm_core_utils/streaming_handler.py
@@ -982,10 +982,51 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
]

if anthropic_response_obj["usage"] is not None:
# Extract token details from usage if available
usage_data = anthropic_response_obj["usage"]

# Initialize token details
audio_tokens = 0
text_tokens = 0
image_tokens = 0
reasoning_tokens = None
response_tokens_details = None

# Extract reasoning tokens if available
completion_tokens_details = usage_data.get("completion_tokens_details")
if completion_tokens_details is not None and "reasoning_tokens" in completion_tokens_details:
reasoning_tokens = completion_tokens_details["reasoning_tokens"]

# Extract prompt tokens details if available
prompt_tokens_details_dict = usage_data.get("prompt_tokens_details")
if prompt_tokens_details_dict is not None:
if "text_tokens" in prompt_tokens_details_dict:
text_tokens = prompt_tokens_details_dict["text_tokens"]
if "audio_tokens" in prompt_tokens_details_dict:
audio_tokens = prompt_tokens_details_dict["audio_tokens"]
if "image_tokens" in prompt_tokens_details_dict:
image_tokens = prompt_tokens_details_dict["image_tokens"]

cached_tokens = text_tokens + audio_tokens + image_tokens
prompt_tokens_details = litellm.types.utils.PromptTokensDetailsWrapper(
cached_tokens=cached_tokens,
audio_tokens=audio_tokens,
text_tokens=text_tokens,
image_tokens=image_tokens
)

# Create usage object with all details
setattr(
model_response,
"usage",
litellm.Usage(**anthropic_response_obj["usage"]),
litellm.Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
prompt_tokens_details=prompt_tokens_details,
reasoning_tokens=reasoning_tokens,
completion_tokens_details=response_tokens_details
),
)

if (
@@ -1113,6 +1154,64 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
self.received_finish_reason = chunk.candidates[ # type: ignore
0
].finish_reason.name

# Extract usage information if available
if hasattr(chunk, "usageMetadata") and chunk.usageMetadata is not None:
usage_metadata = chunk.usageMetadata

cached_tokens = 0
audio_tokens = 0
text_tokens = 0
image_tokens = 0

if hasattr(usage_metadata, "cachedContentTokenCount"):
cached_tokens = usage_metadata.cachedContentTokenCount

# Extract text, audio, and image tokens from promptTokensDetails if available
if hasattr(usage_metadata, "promptTokensDetails"):
for detail in usage_metadata.promptTokensDetails:
if hasattr(detail, "modality") and detail.modality == "AUDIO":
audio_tokens = detail.tokenCount
elif hasattr(detail, "modality") and detail.modality == "TEXT":
text_tokens = detail.tokenCount
elif hasattr(detail, "modality") and detail.modality == "IMAGE":
image_tokens = detail.tokenCount

# Create prompt_tokens_details with all token types
prompt_tokens_details = litellm.types.utils.PromptTokensDetailsWrapper(
cached_tokens=cached_tokens,
audio_tokens=audio_tokens,
text_tokens=text_tokens,
image_tokens=image_tokens
)

# Extract response tokens details if available
response_tokens_details = None
if hasattr(usage_metadata, "responseTokensDetails"):
response_tokens_details = litellm.types.utils.CompletionTokensDetailsWrapper()
for detail in usage_metadata.responseTokensDetails:
if detail.modality == "TEXT":
response_tokens_details.text_tokens = detail.tokenCount
elif detail.modality == "AUDIO":
response_tokens_details.audio_tokens = detail.tokenCount

# Extract reasoning tokens if available
reasoning_tokens = None
if hasattr(usage_metadata, "thoughtsTokenCount"):
reasoning_tokens = usage_metadata.thoughtsTokenCount

setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=getattr(usage_metadata, "promptTokenCount", 0),
completion_tokens=getattr(usage_metadata, "candidatesTokenCount", 0),
total_tokens=getattr(usage_metadata, "totalTokenCount", 0),
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=response_tokens_details,
reasoning_tokens=reasoning_tokens
),
)
except Exception:
if chunk.candidates[0].finish_reason.name == "SAFETY": # type: ignore
raise Exception(
@@ -1221,6 +1320,8 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
self.system_fingerprint = chunk.system_fingerprint
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if hasattr(chunk, "usage") and chunk.usage is not None:
setattr(model_response, "usage", chunk.usage)
else: # openai / azure chat model
if self.custom_llm_provider == "azure":
if isinstance(chunk, BaseModel) and hasattr(chunk, "model"):
@@ -1881,17 +1982,31 @@ def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
"""Assume most recent usage chunk has total usage uptil then."""
prompt_tokens: int = 0
completion_tokens: int = 0
prompt_tokens_details = None
completion_tokens_details = None
reasoning_tokens = None

for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
usage = chunk.get("usage")
if usage is not None:
if "prompt_tokens" in usage:
prompt_tokens = usage.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage:
completion_tokens = usage.get("completion_tokens", 0) or 0
if "prompt_tokens_details" in usage:
prompt_tokens_details = usage.get("prompt_tokens_details")
if "completion_tokens_details" in usage:
completion_tokens_details = usage.get("completion_tokens_details")
if "reasoning_tokens" in usage:
reasoning_tokens = usage.get("reasoning_tokens")

returned_usage_chunk = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
reasoning_tokens=reasoning_tokens,
)

return returned_usage_chunk
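
With the change above, calculate_total_usage carries prompt_tokens_details, completion_tokens_details, and reasoning_tokens from the most recent usage-bearing chunk into the final Usage object rather than discarding them. A small illustration with simplified dict chunks; the values and chunk shapes are made up for the example.

from litellm.types.utils import PromptTokensDetailsWrapper, Usage

chunks = [
    {"choices": [{"delta": {"content": "Hel"}}]},  # early chunks carry no usage
    {
        "usage": {
            "prompt_tokens": 1200,
            "completion_tokens": 80,
            "prompt_tokens_details": PromptTokensDetailsWrapper(cached_tokens=1024),
        }
    },
]

prompt_tokens = completion_tokens = 0
prompt_tokens_details = None
for chunk in chunks:
    usage = chunk.get("usage")
    if usage is not None:
        prompt_tokens = usage.get("prompt_tokens", 0) or 0
        completion_tokens = usage.get("completion_tokens", 0) or 0
        prompt_tokens_details = usage.get("prompt_tokens_details")

total = Usage(
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
    total_tokens=prompt_tokens + completion_tokens,
    prompt_tokens_details=prompt_tokens_details,
)
print(total.prompt_tokens_details.cached_tokens)  # 1024
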