64 changes: 22 additions & 42 deletions litellm/llms/base_llm/chat/transformation.py
@@ -39,6 +39,14 @@
     type_to_response_format_param,
 )
 
+_function_types = (
+    types.FunctionType,
+    types.BuiltinFunctionType,
+    classmethod,
+    staticmethod,
+    property,
+)
+
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
 
@@ -63,19 +71,13 @@ def __init__(
         if request:
             self.request = request
         else:
-            self.request = httpx.Request(
-                method="POST", url="https://docs.litellm.ai/docs"
-            )
+            self.request = httpx.Request(method="POST", url="https://docs.litellm.ai/docs")
         if response:
             self.response = response
         else:
-            self.response = httpx.Response(
-                status_code=status_code, request=self.request
-            )
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         self.body = body
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
+        super().__init__(self.message)  # Call the base class constructor with the parameters it needs
 
 
 class BaseConfig(ABC):
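Illustrative sketch (not part of the diff): the constructor in the hunk above falls back to placeholder httpx objects when no request or response is supplied, so downstream handlers can always inspect them without None checks. A minimal standalone reproduction of that pattern, with the class name and example values chosen here purely for illustration:

import httpx

class ExampleLLMError(Exception):
    # Hypothetical reproduction of the fallback pattern shown above.
    def __init__(self, status_code: int, message: str, request=None, response=None):
        self.status_code = status_code
        self.message = message
        # Fall back to a placeholder request/response so callers never see None.
        self.request = request or httpx.Request(method="POST", url="https://docs.litellm.ai/docs")
        self.response = response or httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)

err = ExampleLLMError(status_code=422, message="Unprocessable entity")
print(err.response.status_code)  # 422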
@@ -84,22 +86,15 @@ def __init__(self):
 
     @classmethod
     def get_config(cls):
+        cls_vars = vars(cls)
         return {
             k: v
-            for k, v in cls.__dict__.items()
+            for k, v in cls_vars.items()
             if not k.startswith("__")
             and not k.startswith("_abc")
             and not k.startswith("_is_base_class")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                    property,
-                ),
-            )
+            and not callable(v)
+            and not isinstance(v, _function_types)
             and v is not None
         }
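Illustrative sketch (not part of the diff): after the refactor, get_config() still returns only plain, non-callable, non-None class attributes; the isinstance tuple now lives in the module-level _function_types and an extra callable() guard is added. A hypothetical subclass showing what survives the filter (attribute names invented for this example; the import path is taken from the file header above):

from litellm.llms.base_llm.chat.transformation import BaseConfig

class ExampleConfig(BaseConfig):
    temperature = 0.7            # kept: plain, non-None value
    max_tokens = 256             # kept
    unset_option = None          # dropped: value is None
    _is_base_class = False       # dropped: name starts with "_is_base_class"

    @classmethod
    def helper(cls):             # dropped: classmethod / callable
        return "ignored"

print(ExampleConfig.get_config())
# expected: {'temperature': 0.7, 'max_tokens': 256}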

@@ -118,14 +113,9 @@ def is_max_tokens_in_request(self, non_default_params: dict) -> bool:
         """
         OpenAI spec allows max_tokens or max_completion_tokens to be specified.
         """
-        return (
-            "max_tokens" in non_default_params
-            or "max_completion_tokens" in non_default_params
-        )
+        return "max_tokens" in non_default_params or "max_completion_tokens" in non_default_params
 
-    def update_optional_params_with_thinking_tokens(
-        self, non_default_params: dict, optional_params: dict
-    ):
+    def update_optional_params_with_thinking_tokens(self, non_default_params: dict, optional_params: dict):
         """
         Handles scenario where max tokens is not specified. For anthropic models (anthropic api/bedrock/vertex ai), this requires having the max tokens being set and being greater than the thinking token budget.
 
@@ -135,13 +125,9 @@ def update_optional_params_with_thinking_tokens(
         """
         is_thinking_enabled = self.is_thinking_enabled(optional_params)
         if is_thinking_enabled and "max_tokens" not in non_default_params:
-            thinking_token_budget = cast(dict, optional_params["thinking"]).get(
-                "budget_tokens", None
-            )
+            thinking_token_budget = cast(dict, optional_params["thinking"]).get("budget_tokens", None)
             if thinking_token_budget is not None:
-                optional_params["max_tokens"] = (
-                    thinking_token_budget + DEFAULT_MAX_TOKENS
-                )
+                optional_params["max_tokens"] = thinking_token_budget + DEFAULT_MAX_TOKENS
 
     def should_fake_stream(
         self,
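Illustrative sketch (not part of the diff): the reflowed logic above only raises max_tokens when thinking is enabled, the caller did not set max_tokens explicitly, and a budget_tokens value is present; the new ceiling is budget_tokens plus DEFAULT_MAX_TOKENS. A standalone rendition of that rule, with the constant value and the thinking-enabled check assumed here for illustration:

DEFAULT_MAX_TOKENS = 4096  # assumed value for this sketch; litellm defines its own constant

def apply_thinking_budget(non_default_params: dict, optional_params: dict) -> None:
    # Only act when thinking is on and the caller did not choose max_tokens.
    thinking = optional_params.get("thinking") or {}
    is_thinking_enabled = thinking.get("type") == "enabled"  # assumed shape of the thinking param
    if is_thinking_enabled and "max_tokens" not in non_default_params:
        budget = thinking.get("budget_tokens")
        if budget is not None:
            optional_params["max_tokens"] = budget + DEFAULT_MAX_TOKENS

params = {"thinking": {"type": "enabled", "budget_tokens": 1024}}
apply_thinking_budget(non_default_params={}, optional_params=params)
print(params["max_tokens"])  # 5120 under the assumed constant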
@@ -188,9 +174,7 @@ def should_retry_llm_api_inside_llm_translation_on_http_error(
         """
         return False
 
-    def transform_request_on_unprocessable_entity_error(
-        self, e: httpx.HTTPStatusError, request_data: dict
-    ) -> dict:
+    def transform_request_on_unprocessable_entity_error(self, e: httpx.HTTPStatusError, request_data: dict) -> dict:
         """
         Transform the request data on UnprocessableEntityError
         """
@@ -237,16 +221,12 @@ def _add_response_format_to_tools(
         if json_schema and not is_response_format_supported:
             _tool_choice = ChatCompletionToolChoiceObjectParam(
                 type="function",
-                function=ChatCompletionToolChoiceFunctionParam(
-                    name=RESPONSE_FORMAT_TOOL_NAME
-                ),
+                function=ChatCompletionToolChoiceFunctionParam(name=RESPONSE_FORMAT_TOOL_NAME),
             )
 
             _tool = ChatCompletionToolParam(
                 type="function",
-                function=ChatCompletionToolParamFunctionChunk(
-                    name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema
-                ),
+                function=ChatCompletionToolParamFunctionChunk(name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema),
             )
 
             optional_params.setdefault("tools", [])
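Illustrative sketch (not part of the diff): when a JSON-schema response_format is not natively supported, the hunk above re-expresses it as a forced function tool named RESPONSE_FORMAT_TOOL_NAME. A simplified, dict-only rendition of that translation; the constant's value and the final tool_choice assignment (which falls outside the visible hunk) are assumptions:

RESPONSE_FORMAT_TOOL_NAME = "json_tool_call"  # assumed value for this sketch

def response_format_to_tool(json_schema: dict, optional_params: dict) -> None:
    # Emulate structured output by forcing a single function call whose
    # parameters are the requested JSON schema.
    tool = {
        "type": "function",
        "function": {"name": RESPONSE_FORMAT_TOOL_NAME, "parameters": json_schema},
    }
    optional_params.setdefault("tools", [])
    optional_params["tools"].append(tool)
    optional_params["tool_choice"] = {
        "type": "function",
        "function": {"name": RESPONSE_FORMAT_TOOL_NAME},
    }

params: dict = {}
response_format_to_tool({"type": "object", "properties": {"answer": {"type": "string"}}}, params)
print(params["tool_choice"]["function"]["name"])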
16 changes: 4 additions & 12 deletions litellm/llms/petals/completion/transformation.py
@@ -37,9 +37,7 @@ class PetalsConfig(BaseConfig):
     """
 
     max_length: Optional[int] = None
-    max_new_tokens: Optional[
-        int
-    ] = litellm.max_tokens  # petals requires max tokens to be set
+    max_new_tokens: Optional[int] = litellm.max_tokens  # petals requires max tokens to be set
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
     top_k: Optional[int] = None
@@ -49,9 +47,7 @@ class PetalsConfig(BaseConfig):
     def __init__(
         self,
         max_length: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        max_new_tokens: Optional[int] = litellm.max_tokens,  # petals requires max tokens to be set
         do_sample: Optional[bool] = None,
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
@@ -67,12 +63,8 @@ def __init__(
     def get_config(cls):
         return super().get_config()
 
-    def get_error_class(
-        self, error_message: str, status_code: int, headers: Union[dict, Headers]
-    ) -> BaseLLMException:
-        return PetalsError(
-            status_code=status_code, message=error_message, headers=headers
-        )
+    def get_error_class(self, error_message: str, status_code: int, headers: Union[dict, Headers]) -> BaseLLMException:
+        return PetalsError(status_code=status_code, message=error_message, headers=headers)
 
     def get_supported_openai_params(self, model: str) -> List:
         return ["max_tokens", "temperature", "top_p", "stream"]