from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import validate_config_keys
@@ -56,6 +57,24 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:

        logger.debug("config=<%s> | initializing", self.config)

+    def _handle_context_window_overflow(self, e: Exception) -> None:
+        """Handle context window overflow errors from LiteLLM.
+
+        Args:
+            e: The exception to handle.
+
+        Raises:
+            ContextWindowOverflowException: If the exception is a context window overflow error.
+        """
+        # Prefer litellm-specific typed exception if exposed
+        litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None)
+        if litellm_exc_type and isinstance(e, litellm_exc_type):
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
+
+        # Not a context-window error, re-raise original
+        raise e
+
    @override
    def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
        """Update the LiteLLM model configuration with the provided arguments.
@@ -135,7 +154,10 @@ async def stream(
        logger.debug("request=<%s>", request)

        logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except Exception as e:
+            self._handle_context_window_overflow(e)

        logger.debug("got response from model")
        yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +227,23 @@ async def structured_output(
        Yields:
            Model events with the last being the structured output.
        """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, we cannot reliably parse structured output.
+        # In that case we must not call the provider and must raise the documented ValueError.
+        if not supports_schema:
            raise ValueError("Model does not support response_format")

-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that DO support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except Exception as e:
+            self._handle_context_window_overflow(e)

        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the response.")
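Taken together, these hunks route every litellm.acompletion failure through _handle_context_window_overflow, so callers see the SDK's ContextWindowOverflowException rather than a provider-specific error, and every other exception is surfaced unchanged. The standalone sketch below illustrates that mapping contract under stated assumptions: FakeLiteLLM and FakeContextWindowExceededError are stand-ins invented here for the real litellm module and its ContextWindowExceededError type, and the local ContextWindowOverflowException class stands in for the one the diff imports from ..types.exceptions.

# Standalone sketch of the overflow-mapping pattern used in this diff.
# FakeLiteLLM / FakeContextWindowExceededError are invented stand-ins, not
# part of litellm or the SDK; they exist only to make the example runnable.

class ContextWindowOverflowException(Exception):
    """Stand-in for the SDK's context window overflow exception."""


class FakeContextWindowExceededError(Exception):
    """Stand-in for litellm.ContextWindowExceededError."""


class FakeLiteLLM:
    ContextWindowExceededError = FakeContextWindowExceededError


litellm = FakeLiteLLM()


def handle_context_window_overflow(e: Exception) -> None:
    # Prefer the provider-specific typed exception if the module exposes it.
    litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None)
    if litellm_exc_type and isinstance(e, litellm_exc_type):
        raise ContextWindowOverflowException(e) from e
    # Anything else is not a context-window problem; surface it unchanged.
    raise e


# An overflow error coming back from the provider is translated...
try:
    handle_context_window_overflow(FakeContextWindowExceededError("too many tokens"))
except ContextWindowOverflowException as mapped:
    print("mapped:", mapped)

# ...while unrelated errors pass through with their original type.
try:
    handle_context_window_overflow(ValueError("bad request"))
except ValueError as passthrough:
    print("passthrough:", passthrough)

The getattr lookup mirrors the hedge in the diff itself: if a given litellm build does not expose ContextWindowExceededError, the helper simply falls through and re-raises the original exception instead of failing on the attribute access.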