52 changes: 30 additions & 22 deletions litellm/litellm_core_utils/streaming_handler.py
@@ -145,6 +145,7 @@ def __init__(
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
+        self._repeated_messages_count = 1
         self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
         self.created: Optional[int] = None

@@ -190,37 +191,44 @@ def process_chunk(self, chunk: str):
         except Exception as e:
             raise e

-    def safety_checker(self) -> None:
+    def raise_on_model_repetition(self) -> None:
         """
         Fixes - https://github.com/BerriAI/litellm/issues/5158

         if the model enters a loop and starts repeating the same chunk again, break out of loop and raise an internalservererror - allows for retries.

         Raises - InternalServerError, if LLM enters infinite loop while streaming
         """
-        if len(self.chunks) >= litellm.REPEATED_STREAMING_CHUNK_LIMIT:
-            # Get the last n chunks
-            last_chunks = self.chunks[-litellm.REPEATED_STREAMING_CHUNK_LIMIT :]
+        if len(self.chunks) < 2:
+            return

-            # Extract the relevant content from the chunks
-            last_contents = [chunk.choices[0].delta.content for chunk in last_chunks]
+        last_content = self.chunks[-1].choices[0].delta.content

-            # Check if all extracted contents are identical
-            if all(content == last_contents[0] for content in last_contents):
-                if (
-                    last_contents[0] is not None
-                    and isinstance(last_contents[0], str)
-                    and len(last_contents[0]) > 2
-                ):  # ignore empty content - https://github.com/BerriAI/litellm/issues/5158#issuecomment-2287156946
-                    # All last n chunks are identical
-                    raise litellm.InternalServerError(
-                        message="The model is repeating the same chunk = {}.".format(
-                            last_contents[0]
-                        ),
-                        model="",
-                        llm_provider="",
-                    )
+        if (
+            last_content is None
+            or not isinstance(last_content, str)
+            or len(last_content) <= 2
+        ):  # ignore empty content - https://github.com/BerriAI/litellm/issues/5158#issuecomment-2287156946
+            self._repeated_messages_count = 1
+            return
+
+        second_to_last_content = self.chunks[-2].choices[0].delta.content
+
+        if last_content == second_to_last_content:
+            self._repeated_messages_count += 1
+        else:
+            self._repeated_messages_count = 1
+
+        if self._repeated_messages_count >= litellm.REPEATED_STREAMING_CHUNK_LIMIT:
+            # All last n chunks are identical
+            raise litellm.InternalServerError(
+                message="The model is repeating the same chunk = {}.".format(
+                    last_content
+                ),
+                model="",
+                llm_provider="",
+            )

     def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
         """
         Output parse <s> / </s> special tokens for sagemaker + hf streaming.
@@ -879,7 +887,7 @@ def return_processed_chunk_logic(  # noqa
         if (
             is_chunk_non_empty
         ):  # cannot set content of an OpenAI Object to be an empty string
-            self.safety_checker()
+            self.raise_on_model_repetition()
             hold, model_response_str = self.check_special_tokens(
                 chunk=completion_obj["content"],
                 finish_reason=model_response.choices[0].finish_reason,
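The rewrite above replaces safety_checker, which sliced off the last REPEATED_STREAMING_CHUNK_LIMIT chunks and compared all of them on every call, with raise_on_model_repetition, which compares only the newest chunk against the previous one and keeps a running _repeated_messages_count, raising litellm.InternalServerError once that counter reaches litellm.REPEATED_STREAMING_CHUNK_LIMIT. The docstring notes that the error is meant to allow retries; below is a minimal sketch of how a caller might react to it. litellm.completion(..., stream=True) and litellm.InternalServerError are existing litellm APIs, while the helper name, retry loop, and model name are illustrative and not part of this PR.

import litellm


def stream_with_retry(messages: list, max_attempts: int = 3) -> str:
    """Illustrative only (not part of litellm): retry a stream that trips the repetition guard."""
    for attempt in range(max_attempts):
        try:
            text = ""
            response = litellm.completion(
                model="gpt-4o-mini",  # illustrative model name
                messages=messages,
                stream=True,
            )
            for chunk in response:
                text += chunk.choices[0].delta.content or ""
            return text
        except litellm.InternalServerError:
            # Raised by raise_on_model_repetition() when the model keeps emitting
            # the same chunk; retrying gives the provider another chance.
            if attempt == max_attempts - 1:
                raise
    return ""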
117 changes: 117 additions & 0 deletions tests/test_litellm/litellm_core_utils/test_streaming_handler.py
@@ -1162,3 +1162,120 @@ def test_is_chunk_non_empty_with_valid_tool_calls(
)
is True
)


def _make_chunk(content: Optional[str]) -> ModelResponseStream:
return ModelResponseStream(
id="test",
created=1741037890,
model="test-model",
choices=[StreamingChoices(index=0, delta=Delta(content=content))],
)


def _build_chunks(pattern: list[str], N: int) -> list[ModelResponseStream]:
"""
Build a list of chunks based on a pattern specification.
"""
chunks = []
for i, p in enumerate(pattern):
if p == "same":
chunks.append(_make_chunk("same_chunk"))
elif p == "diff":
chunks.append(_make_chunk(f"chunk_{i}"))
else:
chunks.append(_make_chunk(p))
return chunks

_REPETITION_TEST_CASES = [
# Basic cases
pytest.param(
["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
True,
id="all_identical_raises",
),
pytest.param(
["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1),
False,
id="below_threshold_no_raise",
),
pytest.param(
[None] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="none_content_no_raise",
),
pytest.param(
[""] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="empty_content_no_raise",
),
# Short content (len <= 2) should not raise
pytest.param(
["##"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="short_content_2chars_no_raise",
),
pytest.param(
["{"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="short_content_1char_no_raise",
),
pytest.param(
["ab"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="short_content_2chars_ab_no_raise",
),
# All different chunks
pytest.param(
["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT,
False,
id="all_different_no_raise",
),
# One chunk different at various positions
pytest.param(
["different_first"] + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1),
False,
id="first_chunk_different_no_raise",
),
pytest.param(
["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1) + ["different_last"],
False,
id="last_chunk_different_no_raise",
),
pytest.param(
["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1) + ["different_mid"] + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1),
False,
id="middle_chunk_different_no_raise",
),
pytest.param(
["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 2) + ["diff", "diff"],
False,
id="last_two_different_no_raise",
),
pytest.param(
["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["diff"],
True,
id="in_between_same_and_diff_raise",
),
]


@pytest.mark.parametrize("chunks_pattern,should_raise", _REPETITION_TEST_CASES)
def test_raise_on_model_repetition(
initialized_custom_stream_wrapper: CustomStreamWrapper,
chunks_pattern: list,
should_raise: bool,
):
wrapper = initialized_custom_stream_wrapper
chunks = _build_chunks(chunks_pattern, len(chunks_pattern))

if should_raise:
with pytest.raises(litellm.InternalServerError) as exc_info:
for chunk in chunks:
wrapper.chunks.append(chunk)
wrapper.raise_on_model_repetition()
assert "repeating the same chunk" in str(exc_info.value)
else:
for chunk in chunks:
wrapper.chunks.append(chunk)
wrapper.raise_on_model_repetition()
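All of the parametrized cases above are built through _build_chunks. Assuming the helpers exactly as added in this diff, the pattern sentinels expand as follows; the expected contents in the comment follow directly from that implementation.

# Illustration only, reusing _build_chunks from the test module above.
chunks = _build_chunks(["same", "same", "diff", "hello"], N=4)
contents = [chunk.choices[0].delta.content for chunk in chunks]
# contents == ["same_chunk", "same_chunk", "chunk_2", "hello"]
# Repeated "same" entries are what push _repeated_messages_count up to
# litellm.REPEATED_STREAMING_CHUNK_LIMIT in the cases expected to raise.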