|
13 | 13 | from typing_extensions import Unpack, override |
14 | 14 |
|
15 | 15 | from ..types.content import ContentBlock, Messages |
| 16 | +from ..types.exceptions import ContextWindowOverflowException |
16 | 17 | from ..types.streaming import StreamEvent |
17 | 18 | from ..types.tools import ToolChoice, ToolSpec |
18 | 19 | from ._validation import validate_config_keys |
|
22 | 23 |
|
23 | 24 | T = TypeVar("T", bound=BaseModel) |
24 | 25 |
|
| 26 | +# See: https://github.com/BerriAI/litellm/blob/main/litellm/exceptions.py |
| 27 | +# The following are common substrings found in context window related errors |
| 28 | +# from various models proxied by LiteLLM. |
| 29 | +LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ |
| 30 | + "Context Window Error", |
| 31 | + "Context Window Exceeded", |
| 32 | + "ContextWindowExceeded", |
| 33 | + "Context window exceeded", |
| 34 | + "Input is too long", |
| 35 | + "ContextWindowExceededError", |
| 36 | +] |
| 37 | + |
25 | 38 |
|
26 | 39 | class LiteLLMModel(OpenAIModel): |
27 | 40 | """LiteLLM model provider implementation.""" |
@@ -56,6 +69,32 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: |
56 | 69 |
|
57 | 70 | logger.debug("config=<%s> | initializing", self.config) |
58 | 71 |
|
| 72 | + def _handle_context_window_overflow(self, e: Exception) -> None: |
| 73 | + """Handle context window overflow errors from LiteLLM. |
| 74 | +
|
| 75 | + Args: |
| 76 | + e: The exception to handle. |
| 77 | +
|
| 78 | + Raises: |
| 79 | + ContextWindowOverflowException: If the exception is a context window overflow error; otherwise the original exception is re-raised. |
| 80 | + """ |
| 81 | + # Prefer litellm-specific typed exception if exposed |
| 82 | + litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None) or getattr( |
| 83 | + litellm, "ContextWindowExceeded", None |
| 84 | + ) |
| 85 | + if litellm_exc_type and isinstance(e, litellm_exc_type): |
| 86 | + logger.warning("litellm client raised context window overflow") |
| 87 | + raise ContextWindowOverflowException(e) from e |
| 88 | + |
| 89 | + # Fallback to substring checks similar to Bedrock handling |
| 90 | + error_message = str(e) |
| 91 | + if any(substr in error_message for substr in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES): |
| 92 | + logger.warning("litellm threw context window overflow error") |
| 93 | + raise ContextWindowOverflowException(e) from e |
| 94 | + |
| 95 | + # Not a context-window error — re-raise original |
| 96 | + raise e |
| 97 | + |
59 | 98 | @override |
60 | 99 | def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type: ignore[override] |
61 | 100 | """Update the LiteLLM model configuration with the provided arguments. |
@@ -135,7 +174,10 @@ async def stream( |
135 | 174 | logger.debug("request=<%s>", request) |
136 | 175 |
|
137 | 176 | logger.debug("invoking model") |
138 | | - response = await litellm.acompletion(**self.client_args, **request) |
| 177 | + try: |
| 178 | + response = await litellm.acompletion(**self.client_args, **request) |
| 179 | + except Exception as e: |
| 180 | + self._handle_context_window_overflow(e) |
139 | 181 |
|
140 | 182 | logger.debug("got response from model") |
141 | 183 | yield self.format_chunk({"chunk_type": "message_start"}) |
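
With the mapping in place, callers of `stream` can handle overflow uniformly. A consumer-side sketch, assuming the `stream(messages, tool_specs, system_prompt)` signature of this module; the retry strategy (dropping older turns) is hypothetical and not part of this PR:

```python
from strands.types.exceptions import ContextWindowOverflowException  # assumed module path


async def stream_with_truncation(model, messages, tool_specs=None, system_prompt=None):
    """Retry once with a shortened conversation if the context window overflows."""
    try:
        async for event in model.stream(messages, tool_specs, system_prompt):
            yield event
    except ContextWindowOverflowException:
        # Naive recovery for illustration: keep only the most recent turns and retry once.
        async for event in model.stream(messages[-4:], tool_specs, system_prompt):
            yield event
```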
@@ -205,15 +247,23 @@ async def structured_output( |
205 | 247 | Yields: |
206 | 248 | Model events with the last being the structured output. |
207 | 249 | """ |
208 | | - if not supports_response_schema(self.get_config()["model_id"]): |
| 250 | + supports_schema = supports_response_schema(self.get_config()["model_id"]) |
| 251 | + |
| 252 | + # If the provider does not support response schemas, we cannot reliably parse structured output. |
| 253 | + # In that case we must not call the provider and must raise the documented ValueError. |
| 254 | + if not supports_schema: |
209 | 255 | raise ValueError("Model does not support response_format") |
210 | 256 |
|
211 | | - response = await litellm.acompletion( |
212 | | - **self.client_args, |
213 | | - model=self.get_config()["model_id"], |
214 | | - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], |
215 | | - response_format=output_model, |
216 | | - ) |
| 257 | + # For providers that DO support response schemas, call litellm and map context-window errors. |
| 258 | + try: |
| 259 | + response = await litellm.acompletion( |
| 260 | + **self.client_args, |
| 261 | + model=self.get_config()["model_id"], |
| 262 | + messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], |
| 263 | + response_format=output_model, |
| 264 | + ) |
| 265 | + except Exception as e: |
| 266 | + self._handle_context_window_overflow(e) |
217 | 267 |
|
218 | 268 | if len(response.choices) > 1: |
219 | 269 | raise ValueError("Multiple choices found in the response.") |
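
For the structured-output path, a hedged usage sketch. The `Weather` model and prompt are illustrative, the positional-argument signature is inferred from the docstring above, and the prompt shape follows the `Messages`/`ContentBlock` types imported at the top of this module; per the docstring, the last yielded event carries the structured output:

```python
from pydantic import BaseModel


class Weather(BaseModel):
    city: str
    temperature_c: float


async def get_weather(model):
    # Assumes the configured model_id supports response schemas; otherwise
    # structured_output raises ValueError("Model does not support response_format").
    last_event = None
    prompt = [{"role": "user", "content": [{"text": "What is the weather in Paris?"}]}]
    async for event in model.structured_output(Weather, prompt):
        last_event = event
    return last_event
```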
|