Skip to content

Commit c734c01

Browse files
committed
fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException
1 parent 776fd93 commit c734c01

File tree

2 files changed

+85
-8
lines changed

2 files changed

+85
-8
lines changed

src/strands/models/litellm.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing_extensions import Unpack, override
1414

1515
from ..types.content import ContentBlock, Messages
16+
from ..types.exceptions import ContextWindowOverflowException
1617
from ..types.streaming import StreamEvent
1718
from ..types.tools import ToolChoice, ToolSpec
1819
from ._validation import validate_config_keys
@@ -22,6 +23,18 @@
2223

2324
T = TypeVar("T", bound=BaseModel)
2425

26+
# See: https://github.com/BerriAI/litellm/blob/main/litellm/exceptions.py
# The following are common substrings found in context window related errors
# from various models proxied by LiteLLM. Matching is case-sensitive substring
# containment, so near-duplicate casings are listed separately on purpose.
# NOTE: "ContextWindowExceededError" is redundant — it already contains the
# earlier entry "ContextWindowExceeded" — but it is kept for explicitness.
LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES: list[str] = [
    "Context Window Error",
    "Context Window Exceeded",
    "ContextWindowExceeded",
    "Context window exceeded",
    "Input is too long",
    "ContextWindowExceededError",
]
37+
2538

2639
class LiteLLMModel(OpenAIModel):
2740
"""LiteLLM model provider implementation."""
@@ -56,6 +69,32 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
5669

5770
logger.debug("config=<%s> | initializing", self.config)
5871

72+
def _handle_context_window_overflow(self, e: Exception) -> None:
    """Translate LiteLLM context-window overflow errors into the Strands exception type.

    Recognition happens in two passes: a typed check against litellm's own
    exception class (when the installed litellm version exposes one), then a
    case-sensitive substring scan of the error message. Anything unrecognized
    is re-raised unchanged.

    Args:
        e: The exception raised by the litellm client.

    Raises:
        ContextWindowOverflowException: If ``e`` is a context window overflow error.
    """
    # Pass 1: prefer litellm's typed exception if this litellm version exposes it.
    typed_overflow = getattr(litellm, "ContextWindowExceededError", None)
    if typed_overflow is None:
        typed_overflow = getattr(litellm, "ContextWindowExceeded", None)
    if typed_overflow is not None and isinstance(e, typed_overflow):
        logger.warning("litellm client raised context window overflow")
        raise ContextWindowOverflowException(e) from e

    # Pass 2: substring matching on the message, mirroring the Bedrock provider's handling.
    message = str(e)
    for needle in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES:
        if needle in message:
            logger.warning("litellm threw context window overflow error")
            raise ContextWindowOverflowException(e) from e

    # Not a context-window error — propagate the original exception untouched.
    raise e
97+
5998
@override
6099
def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type: ignore[override]
61100
"""Update the LiteLLM model configuration with the provided arguments.
@@ -135,7 +174,10 @@ async def stream(
135174
logger.debug("request=<%s>", request)
136175

137176
logger.debug("invoking model")
138-
response = await litellm.acompletion(**self.client_args, **request)
177+
try:
178+
response = await litellm.acompletion(**self.client_args, **request)
179+
except Exception as e:
180+
self._handle_context_window_overflow(e)
139181

140182
logger.debug("got response from model")
141183
yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +247,23 @@ async def structured_output(
205247
Yields:
206248
Model events with the last being the structured output.
207249
"""
208-
if not supports_response_schema(self.get_config()["model_id"]):
250+
supports_schema = supports_response_schema(self.get_config()["model_id"])
251+
252+
# If the provider does not support response schemas, we cannot reliably parse structured output.
253+
# In that case we must not call the provider and must raise the documented ValueError.
254+
if not supports_schema:
209255
raise ValueError("Model does not support response_format")
210256

211-
response = await litellm.acompletion(
212-
**self.client_args,
213-
model=self.get_config()["model_id"],
214-
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
215-
response_format=output_model,
216-
)
257+
# For providers that DO support response schemas, call litellm and map context-window errors.
258+
try:
259+
response = await litellm.acompletion(
260+
**self.client_args,
261+
model=self.get_config()["model_id"],
262+
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
263+
response_format=output_model,
264+
)
265+
except Exception as e:
266+
self._handle_context_window_overflow(e)
217267

218268
if len(response.choices) > 1:
219269
raise ValueError("Multiple choices found in the response.")

tests/strands/models/test_litellm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import strands
88
from strands.models.litellm import LiteLLMModel
9+
from strands.types.exceptions import ContextWindowOverflowException
910

1011

1112
@pytest.fixture
@@ -301,6 +302,32 @@ async def test_structured_output_unsupported_model(litellm_acompletion, model, t
301302
litellm_acompletion.assert_not_called()
302303

303304

305+
@pytest.mark.asyncio
async def test_stream_context_window_maps_to_exception(litellm_acompletion, model):
    """stream() should surface context-window failures as ContextWindowOverflowException."""
    # Simulate the provider rejecting the request because the prompt exceeds its context window.
    litellm_acompletion.side_effect = Exception("Input is too long for requested model")

    messages = [{"role": "user", "content": [{"text": "x"}]}]
    with pytest.raises(ContextWindowOverflowException):
        # stream() is an async generator; the mapped exception surfaces on consumption.
        async for _event in model.stream(messages):
            pass
314+
315+
@pytest.mark.asyncio
316+
async def test_structured_output_context_window_maps_to_exception(litellm_acompletion, model, test_output_model_cls):
317+
# Litellm structured_output path raising similar message should be mapped too.
318+
litellm_acompletion.side_effect = Exception("Context Window Error - Input too long")
319+
320+
# Ensure supports_response_schema returns True so structured_output will call litellm.acompletion
321+
# and we can observe mapping to ContextWindowOverflowException.
322+
with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
323+
with pytest.raises(ContextWindowOverflowException):
324+
# structured_output is async generator; consuming it should raise our mapped exception.
325+
async for _ in model.structured_output(
326+
output_model=test_output_model_cls, prompt=[{"role": "user", "content": [{"text": "x"}]}]
327+
):
328+
pass
329+
330+
304331
def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
305332
"""Test that unknown config keys emit a warning."""
306333
LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test")

0 commit comments

Comments (0)