Commit 64096ae

Codestral - return litellm latency overhead on /v1/completions + Add '__contains__' support for ChatCompletionDeltaToolCall (#10879)
* feat(codestral/completion): return litellm latency overhead for codestral. Enables easier debugging of latency issues.
* fix(types/utils.py): support _response_ms on hidden params model dump. Fixes issue where 'x-litellm-overhead-duration-ms' wasn't being returned on text completion calls.
* fix(types/utils.py): add '__contains__' support for ChatCompletionDeltaToolCall. Fixes #7099.
* fix: fix linting error
* fix: fix linting error
1 parent acaa802 commit 64096ae
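The latency-overhead value described above is surfaced as a response header on /v1/completions. A minimal sketch of how to inspect it against a running proxy; the local URL and API key below are placeholder assumptions, not part of this commit:

# Sketch: read the litellm overhead header from a text completion call.
# Assumes a litellm proxy listening on localhost:4000 and a placeholder virtual key.
import requests

resp = requests.post(
    "http://localhost:4000/v1/completions",        # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},   # placeholder key
    json={"model": "codestral-latest", "prompt": "def fib(n):", "max_tokens": 32},
)

# Header name taken from the commit message above.
print(resp.headers.get("x-litellm-overhead-duration-ms"))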

8 files changed: +90 −111 lines changed


litellm/llms/codestral/completion/handler.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
 from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
@@ -333,6 +334,7 @@ def completion(
             encoding=encoding,
         )

+    @track_llm_api_timing()
     async def async_completion(
         self,
         model: str,
@@ -382,6 +384,7 @@ async def async_completion(
             encoding=encoding,
         )

+    @track_llm_api_timing()
     async def async_streaming(
         self,
         model: str,
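track_llm_api_timing comes from litellm.litellm_core_utils.logging_utils and its implementation is not part of this diff. As a rough, illustrative sketch only (not the litellm source), a decorator of this kind wraps the async call, measures wall-clock time, and records the duration so it can later be reported as overhead:

# Illustrative sketch of the timing-decorator pattern; not litellm's actual implementation.
import asyncio
import functools
import time


def track_timing_sketch():
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await func(*args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000
                # a real implementation would attach this to the request's logging object
                print(f"{func.__name__} took {elapsed_ms:.1f} ms")

        return wrapper

    return decorator


@track_timing_sketch()
async def fake_llm_call():
    await asyncio.sleep(0.1)  # stand-in for the provider HTTP call
    return "ok"


asyncio.run(fake_llm_call())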

litellm/llms/codestral/completion/transformation.py

Lines changed: 6 additions & 0 deletions
@@ -104,6 +104,12 @@ def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:

         original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
         _choices = chunk_data_dict.get("choices", []) or []
+        if len(_choices) == 0:
+            return {
+                "text": "",
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
         _choice = _choices[0]
         text = _choice.get("delta", {}).get("content", "")
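This guard avoids an IndexError from _choices[0] when a streamed chunk arrives with an empty choices list; the parser now returns a chunk with empty text and the finish state that was already computed. A standalone illustration of the same pattern (parse_chunk and its types are hypothetical names for this sketch, not the litellm API):

# Hypothetical standalone illustration of the empty-choices guard added above.
from typing import Optional, TypedDict


class Chunk(TypedDict):
    text: str
    is_finished: bool
    finish_reason: Optional[str]


def parse_chunk(chunk_data_dict: dict) -> Chunk:
    choices = chunk_data_dict.get("choices", []) or []
    if len(choices) == 0:
        # nothing to index into -- return an empty chunk instead of raising IndexError
        return {"text": "", "is_finished": False, "finish_reason": None}
    delta = choices[0].get("delta", {})
    return {
        "text": delta.get("content", "") or "",
        "is_finished": choices[0].get("finish_reason") is not None,
        "finish_reason": choices[0].get("finish_reason"),
    }


print(parse_chunk({"choices": []}))  # {'text': '', 'is_finished': False, 'finish_reason': None}
print(parse_chunk({"choices": [{"delta": {"content": "Hi"}}]}))  # {'text': 'Hi', ...}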

litellm/proxy/_new_secret_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,4 +8,4 @@ litellm_settings:
   cache: true
   success_callback: ["langfuse"]
   failure_callback: ["langfuse"]
-  alerting: ["slack"]
+  alerting: ["slack"]

litellm/proxy/common_request_processing.py

Lines changed: 2 additions & 0 deletions
@@ -241,6 +241,7 @@ async def common_processing_pre_call_logic(
         "acreate_batch",
         "aretrieve_batch",
         "afile_content",
+        "atext_completion",
         "acreate_fine_tuning_job",
         "acancel_fine_tuning_job",
         "alist_fine_tuning_jobs",
@@ -322,6 +323,7 @@ async def base_process_llm_request(
         "_arealtime",
         "aget_responses",
         "adelete_responses",
+        "atext_completion",
         "aimage_edit",
     ],
     proxy_logging_obj: ProxyLogging,

litellm/proxy/proxy_server.py

Lines changed: 13 additions & 105 deletions
@@ -3643,117 +3643,25 @@ async def completion(  # noqa: PLR0915
     data = {}
     try:
         data = await _read_request_body(request=request)
-
-        data["model"] = (
-            general_settings.get("completion_model", None)  # server default
-            or user_model  # model name passed via cli args
-            or model  # for azure deployments
-            or data.get("model", None)
-        )
-        if user_model:
-            data["model"] = user_model
-
-        data = await add_litellm_data_to_request(
-            data=data,
+        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
+        return await base_llm_response_processor.base_process_llm_request(
             request=request,
-            general_settings=general_settings,
+            fastapi_response=fastapi_response,
             user_api_key_dict=user_api_key_dict,
-            version=version,
-            proxy_config=proxy_config,
-        )
-
-        # override with user settings, these are params passed via cli
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-
-        ### MODEL ALIAS MAPPING ###
-        # check if model name in model alias map
-        # get the actual model name
-        if data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
-
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
-            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
-        )
-
-        ### ROUTE THE REQUESTs ###
-        llm_call = await route_request(
-            data=data,
             route_type="atext_completion",
+            proxy_logging_obj=proxy_logging_obj,
             llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=model,
             user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
         )
-
-        # Await the llm_response task
-        response = await llm_call
-
-        hidden_params = getattr(response, "_hidden_params", {}) or {}
-        model_id = hidden_params.get("model_id", None) or ""
-        cache_key = hidden_params.get("cache_key", None) or ""
-        api_base = hidden_params.get("api_base", None) or ""
-        response_cost = hidden_params.get("response_cost", None) or ""
-        litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
-
-        ### ALERTING ###
-        asyncio.create_task(
-            proxy_logging_obj.update_request_status(
-                litellm_call_id=data.get("litellm_call_id", ""), status="success"
-            )
-        )
-
-        verbose_proxy_logger.debug("final response: %s", response)
-        if (
-            "stream" in data and data["stream"] is True
-        ):  # use generate_responses to stream responses
-            custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
-                user_api_key_dict=user_api_key_dict,
-                call_id=litellm_call_id,
-                model_id=model_id,
-                cache_key=cache_key,
-                api_base=api_base,
-                version=version,
-                response_cost=response_cost,
-                hidden_params=hidden_params,
-                request_data=data,
-            )
-            selected_data_generator = select_data_generator(
-                response=response,
-                user_api_key_dict=user_api_key_dict,
-                request_data=data,
-            )
-
-            return await create_streaming_response(
-                generator=selected_data_generator,
-                media_type="text/event-stream",
-                headers=custom_headers,
-            )
-        ### CALL HOOKS ### - modify outgoing data
-        response = await proxy_logging_obj.post_call_success_hook(
-            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
-        )
-
-        fastapi_response.headers.update(
-            ProxyBaseLLMRequestProcessing.get_custom_headers(
-                user_api_key_dict=user_api_key_dict,
-                call_id=litellm_call_id,
-                model_id=model_id,
-                cache_key=cache_key,
-                api_base=api_base,
-                version=version,
-                response_cost=response_cost,
-                request_data=data,
-                hidden_params=hidden_params,
-            )
-        )
-        await check_response_size_is_safe(response=response)
-        return response
     except RejectedRequestError as e:
         _data = e.request_data
         await proxy_logging_obj.post_call_failure_hook(

litellm/router_utils/add_retry_fallback_headers.py

Lines changed: 7 additions & 5 deletions
@@ -17,14 +17,16 @@ def _add_headers_to_response(response: Any, headers: dict) -> Any:
     )

     if hidden_params is None:
-        hidden_params = {}
+        hidden_params_dict = {}
     elif isinstance(hidden_params, HiddenParams):
-        hidden_params = hidden_params.model_dump()
+        hidden_params_dict = hidden_params.model_dump()
+    else:
+        hidden_params_dict = hidden_params

-    hidden_params.setdefault("additional_headers", {})
-    hidden_params["additional_headers"].update(headers)
+    hidden_params_dict.setdefault("additional_headers", {})
+    hidden_params_dict["additional_headers"].update(headers)

-    setattr(response, "_hidden_params", hidden_params)
+    setattr(response, "_hidden_params", hidden_params_dict)
     return response

litellm/types/utils.py

Lines changed: 23 additions & 0 deletions
@@ -444,11 +444,28 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
     type: Optional[str] = None
     index: int

+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+

 class HiddenParams(OpenAIObject):
     original_response: Optional[Union[str, Any]] = None
     model_id: Optional[str] = None  # used in Router for individual deployments
     api_base: Optional[str] = None  # returns api base used for making completion call
+    _response_ms: Optional[float] = None

     model_config = ConfigDict(extra="allow", protected_namespaces=())

@@ -471,6 +488,12 @@ def json(self, **kwargs):  # type: ignore
         # if using pydantic v1
         return self.dict()

+    def model_dump(self, **kwargs):
+        # Override model_dump to include private attributes
+        data = super().model_dump(**kwargs)
+        data["_response_ms"] = self._response_ms
+        return data
+

 class ChatCompletionMessageToolCall(OpenAIObject):
     def __init__(
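With these dunder methods in place, streamed tool-call deltas can be treated like dicts as well as objects. A small usage sketch; the delta below is constructed by hand rather than taken from a live stream:

# Usage sketch for the new dict-style access on ChatCompletionDeltaToolCall.
from litellm.types.utils import ChatCompletionDeltaToolCall, Function

delta_call = ChatCompletionDeltaToolCall(
    id="call_abc",
    function=Function(name="get_current_weather", arguments='{"location": "Paris"}'),
    type="function",
    index=0,
)

if "function" in delta_call:               # __contains__
    print(delta_call["function"].name)     # __getitem__ -> "get_current_weather"
print(delta_call.get("type", "function"))  # .get() with a default
delta_call["index"] = 1                    # __setitem__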
New test file

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import asyncio
+import os
+import sys
+from typing import Optional
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+sys.path.insert(0, os.path.abspath("../.."))
+import json
+
+from litellm.types.utils import HiddenParams
+
+
+def test_hidden_params_response_ms():
+    hidden_params = HiddenParams()
+    setattr(hidden_params, "_response_ms", 100)
+    hidden_params_dict = hidden_params.model_dump()
+    assert hidden_params_dict.get("_response_ms") == 100
+
+
+def test_chat_completion_delta_tool_call():
+    from litellm.types.utils import ChatCompletionDeltaToolCall, Function
+
+    tool = ChatCompletionDeltaToolCall(
+        id="call_m87w",
+        function=Function(
+            arguments='{"location": "San Francisco", "unit": "imperial"}',
+            name="get_current_weather",
+        ),
+        type="function",
+        index=0,
+    )
+
+    assert "function" in tool
