Codestral - return litellm latency overhead on /v1/completions + Add '__contains__' support for ChatCompletionDeltaToolCall #10879

Merged · 6 commits · May 27, 2025
3 changes: 3 additions & 0 deletions litellm/llms/codestral/completion/handler.py
@@ -9,6 +9,7 @@

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
@@ -333,6 +334,7 @@ def completion(
encoding=encoding,
)

@track_llm_api_timing()
async def async_completion(
self,
model: str,
@@ -382,6 +384,7 @@ async def async_completion(
encoding=encoding,
)

@track_llm_api_timing()
async def async_streaming(
self,
model: str,
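The import and the two decorators above are the whole Codestral change here: wrapping async_completion and async_streaming with track_llm_api_timing measures the time spent in these calls, which feeds the litellm latency-overhead reporting named in the PR title. As a rough, hypothetical sketch of what an async timing decorator of this shape does (the real implementation lives in litellm/litellm_core_utils/logging_utils.py and presumably records into the request's logging object rather than printing):

    import asyncio
    import functools
    import time

    def track_api_timing_sketch():  # hypothetical stand-in, not litellm's decorator
        def decorator(fn):
            @functools.wraps(fn)
            async def wrapper(*args, **kwargs):
                start = time.monotonic()
                try:
                    return await fn(*args, **kwargs)
                finally:
                    elapsed_ms = (time.monotonic() - start) * 1000
                    # the real decorator would attach this measurement to the
                    # call's logging/telemetry objects instead of printing it
                    print(f"{fn.__name__} took {elapsed_ms:.1f} ms")
            return wrapper
        return decorator

    async def _demo():
        @track_api_timing_sketch()
        async def fake_provider_call():
            await asyncio.sleep(0.01)
            return "ok"
        return await fake_provider_call()

    print(asyncio.run(_demo()))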
6 changes: 6 additions & 0 deletions litellm/llms/codestral/completion/transformation.py
@@ -104,6 +104,12 @@ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:

original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or []
if len(_choices) == 0:
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
_choice = _choices[0]
text = _choice.get("delta", {}).get("content", "")

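The added guard handles streaming chunks whose choices list is empty (for instance usage-only chunks in OpenAI-style streams); previously _choices[0] would raise an IndexError on such a chunk. An illustrative, self-contained version of the same guard, with hypothetical names rather than the module's actual helpers (the real parser returns a GenericStreamingChunk):

    def parse_stream_chunk(chunk_data_dict: dict) -> dict:
        choices = chunk_data_dict.get("choices", []) or []
        if not choices:
            # nothing to extract from this chunk; return an empty delta
            return {"text": "", "is_finished": False, "finish_reason": ""}
        choice = choices[0]
        finish_reason = choice.get("finish_reason") or ""
        return {
            "text": (choice.get("delta", {}) or {}).get("content", "") or "",
            "is_finished": bool(finish_reason),
            "finish_reason": finish_reason,
        }

    print(parse_stream_chunk({"id": "cmpl-1", "choices": []}))  # no IndexError
    print(parse_stream_chunk({"id": "cmpl-1", "choices": [{"delta": {"content": "Hi"}}]}))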
2 changes: 1 addition & 1 deletion litellm/proxy/_new_secret_config.yaml
@@ -8,4 +8,4 @@ litellm_settings:
cache: true
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
alerting: ["slack"]
alerting: ["slack"]
2 changes: 2 additions & 0 deletions litellm/proxy/common_request_processing.py
@@ -241,6 +241,7 @@ async def common_processing_pre_call_logic(
"acreate_batch",
"aretrieve_batch",
"afile_content",
"atext_completion",
"acreate_fine_tuning_job",
"acancel_fine_tuning_job",
"alist_fine_tuning_jobs",
@@ -322,6 +323,7 @@ async def base_process_llm_request(
"_arealtime",
"aget_responses",
"adelete_responses",
"atext_completion",
"aimage_edit",
],
proxy_logging_obj: ProxyLogging,
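Both hunks add "atext_completion" to the route_type values these shared-processing functions accept, so /v1/completions traffic can be dispatched through the same path as chat completions. Schematically, and heavily abridged (the real annotations list many more routes):

    from typing import Literal

    # hypothetical, trimmed version of the route_type parameter's annotation
    RouteType = Literal[
        "acompletion",
        "atext_completion",  # newly accepted after this PR
        "aimage_generation",
    ]

    def dispatch(route_type: RouteType) -> str:
        return f"routing via {route_type}"

    print(dispatch("atext_completion"))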
118 changes: 13 additions & 105 deletions litellm/proxy/proxy_server.py
@@ -3643,117 +3643,25 @@ async def completion( # noqa: PLR0915
data = {}
try:
data = await _read_request_body(request=request)

data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or data.get("model", None)
)
if user_model:
data["model"] = user_model

data = await add_litellm_data_to_request(
data=data,
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
return await base_llm_response_processor.base_process_llm_request(
request=request,
general_settings=general_settings,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)

# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base

### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]

### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)

### ROUTE THE REQUESTs ###
llm_call = await route_request(
data=data,
route_type="atext_completion",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=model,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)

# Await the llm_response task
response = await llm_call

hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""

### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)

verbose_proxy_logger.debug("final response: %s", response)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
hidden_params=hidden_params,
request_data=data,
)
selected_data_generator = select_data_generator(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)

return await create_streaming_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
)

fastapi_response.headers.update(
ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=litellm_call_id,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
await check_response_size_is_safe(response=response)
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
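This is where the 105 deleted lines come from: the /v1/completions endpoint used to hand-roll pre-call hooks, routing, alerting, streaming, and response headers, and it now delegates all of that to ProxyBaseLLMRequestProcessing.base_process_llm_request, the shared path other endpoints already use, which is presumably how the litellm overhead information surfaces on /v1/completions as well. A self-contained toy of the consolidation pattern, not litellm's actual classes or signatures:

    import asyncio
    from typing import Any, Awaitable, Callable

    class ToyBaseProcessor:
        """Shared per-request processing: hooks, routing, and headers in one place."""

        def __init__(self, data: dict):
            self.data = data

        async def process(self, route_type: str, call: Callable[[dict], Awaitable[Any]]) -> dict:
            self.data.setdefault("litellm_call_id", "toy-call-id")  # pre-call bookkeeping
            response = await call(self.data)                        # the provider call
            headers = {"x-litellm-call-id": self.data["litellm_call_id"]}
            return {"route": route_type, "response": response, "headers": headers}

    async def fake_text_completion(data: dict) -> str:
        return f"echo: {data.get('prompt', '')}"

    async def main():
        processor = ToyBaseProcessor(data={"model": "codestral-latest", "prompt": "hi"})
        print(await processor.process("atext_completion", fake_text_completion))

    asyncio.run(main())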
12 changes: 7 additions & 5 deletions litellm/router_utils/add_retry_fallback_headers.py
@@ -17,14 +17,16 @@ def _add_headers_to_response(response: Any, headers: dict) -> Any:
)

if hidden_params is None:
hidden_params = {}
hidden_params_dict = {}
elif isinstance(hidden_params, HiddenParams):
hidden_params = hidden_params.model_dump()
hidden_params_dict = hidden_params.model_dump()
else:
hidden_params_dict = hidden_params

hidden_params.setdefault("additional_headers", {})
hidden_params["additional_headers"].update(headers)
hidden_params_dict.setdefault("additional_headers", {})
hidden_params_dict["additional_headers"].update(headers)

setattr(response, "_hidden_params", hidden_params)
setattr(response, "_hidden_params", hidden_params_dict)
return response


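The rename to hidden_params_dict keeps the variable a plain dict in every branch, instead of re-binding a HiddenParams model variable to a dict, so setdefault and update always operate on a dict and the object written back to _hidden_params has a consistent type. A hedged sketch of that normalization, using a stand-in pydantic v2 model rather than litellm's HiddenParams:

    from typing import Optional, Union
    from pydantic import BaseModel

    class ToyHiddenParams(BaseModel):  # stand-in for litellm's HiddenParams
        model_id: Optional[str] = None

    def normalize_hidden_params(hidden_params: Union[None, dict, ToyHiddenParams]) -> dict:
        if hidden_params is None:
            hidden_params_dict: dict = {}
        elif isinstance(hidden_params, ToyHiddenParams):
            hidden_params_dict = hidden_params.model_dump()
        else:
            hidden_params_dict = hidden_params
        hidden_params_dict.setdefault("additional_headers", {})
        return hidden_params_dict

    print(normalize_hidden_params(ToyHiddenParams(model_id="codestral-1")))
    print(normalize_hidden_params(None))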
23 changes: 23 additions & 0 deletions litellm/types/utils.py
@@ -444,11 +444,28 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
type: Optional[str] = None
index: int

def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)

def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)

def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)

def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)


class HiddenParams(OpenAIObject):
original_response: Optional[Union[str, Any]] = None
model_id: Optional[str] = None # used in Router for individual deployments
api_base: Optional[str] = None # returns api base used for making completion call
_response_ms: Optional[float] = None

model_config = ConfigDict(extra="allow", protected_namespaces=())

@@ -471,6 +488,12 @@ def json(self, **kwargs): # type: ignore
# if using pydantic v1
return self.dict()

def model_dump(self, **kwargs):
# Override model_dump to include private attributes
data = super().model_dump(**kwargs)
data["_response_ms"] = self._response_ms
return data


class ChatCompletionMessageToolCall(OpenAIObject):
def __init__(
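Together these two additions let ChatCompletionDeltaToolCall be treated like a dict by downstream code ("in", .get, indexing, item assignment) and make HiddenParams.model_dump keep the private _response_ms field that litellm uses for response timing. Illustrative usage against the installed package (the import paths follow the test file below):

    from litellm.types.utils import ChatCompletionDeltaToolCall, Function, HiddenParams

    tool = ChatCompletionDeltaToolCall(
        id="call_1",
        function=Function(arguments="{}", name="noop"),
        type="function",
        index=0,
    )
    assert "function" in tool            # __contains__
    assert tool.get("missing") is None   # .get with a default
    tool["index"] = 1                    # __setitem__
    assert tool["index"] == 1            # __getitem__

    hp = HiddenParams()
    hp._response_ms = 42.0
    assert hp.model_dump()["_response_ms"] == 42.0  # private attr now survives model_dump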
35 changes: 35 additions & 0 deletions tests/litellm/types/test_types_utils.py
@@ -0,0 +1,35 @@
import asyncio
import os
import sys
from typing import Optional
from unittest.mock import AsyncMock, patch

import pytest

sys.path.insert(0, os.path.abspath("../.."))
import json

from litellm.types.utils import HiddenParams


def test_hidden_params_response_ms():
hidden_params = HiddenParams()
setattr(hidden_params, "_response_ms", 100)
hidden_params_dict = hidden_params.model_dump()
assert hidden_params_dict.get("_response_ms") == 100


def test_chat_completion_delta_tool_call():
from litellm.types.utils import ChatCompletionDeltaToolCall, Function

tool = ChatCompletionDeltaToolCall(
id="call_m87w",
function=Function(
arguments='{"location": "San Francisco", "unit": "imperial"}',
name="get_current_weather",
),
type="function",
index=0,
)

assert "function" in tool