Commit 74d8361

GWeale authored and copybara-github committed
fix: Add a fallback user message to LiteLLM requests if the last user message is empty
Related to #3255. Closes #2560. PiperOrigin-RevId: 825143315
1 parent 240ef5b commit 74d8361

File tree

2 files changed: +83, -0 lines

src/google/adk/models/lite_llm.py
tests/unittests/models/test_litellm.py
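
For orientation, the following is a minimal sketch (not an official usage example) of the behavior this commit adds: when the most recent user turn in an LlmRequest carries no usable payload, a fallback text part is appended so LiteLLM backends never receive an empty user message. The import path for LlmRequest is an assumption based on the ADK module layout, and the module-private helper is imported here purely for illustration:

    # Sketch of the fallback behavior added by this commit.
    # Assumptions: LlmRequest lives in google.adk.models.llm_request; the
    # module-private helper is imported only for demonstration.
    from google.adk.models.lite_llm import _append_fallback_user_content_if_missing
    from google.adk.models.llm_request import LlmRequest
    from google.genai import types

    # The last (and only) user turn has no parts, i.e. no usable payload.
    request = LlmRequest(contents=[types.Content(role="user", parts=[])])

    _append_fallback_user_content_if_missing(request)

    # The empty user turn now carries the fallback text.
    print(request.contents[-1].parts[0].text)
    # "Handle the requests as specified in the System Instruction."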

src/google/adk/models/lite_llm.py

Lines changed: 48 additions & 0 deletions
@@ -155,6 +155,53 @@ def _safe_json_serialize(obj) -> str:
   return str(obj)


+def _part_has_payload(part: types.Part) -> bool:
+  """Checks whether a Part contains usable payload for the model."""
+  if part.text:
+    return True
+  if part.inline_data and part.inline_data.data:
+    return True
+  if part.file_data and (part.file_data.file_uri or part.file_data.data):
+    return True
+  return False
+
+
+def _append_fallback_user_content_if_missing(
+    llm_request: LlmRequest,
+) -> None:
+  """Ensures there is a user message with content for LiteLLM backends.
+
+  Args:
+    llm_request: The request that may need a fallback user message.
+  """
+  for content in reversed(llm_request.contents):
+    if content.role == "user":
+      parts = content.parts or []
+      if any(_part_has_payload(part) for part in parts):
+        return
+      if not parts:
+        content.parts = []
+      content.parts.append(
+          types.Part.from_text(
+              text="Handle the requests as specified in the System Instruction."
+          )
+      )
+      return
+  llm_request.contents.append(
+      types.Content(
+          role="user",
+          parts=[
+              types.Part.from_text(
+                  text=(
+                      "Handle the requests as specified in the System"
+                      " Instruction."
+                  )
+              ),
+          ],
+      )
+  )
+
+
 def _content_to_message_param(
     content: types.Content,
 ) -> Union[Message, list[Message]]:
@@ -818,6 +865,7 @@ async def generate_content_async(
     """

     self._maybe_append_user_content(llm_request)
+    _append_fallback_user_content_if_missing(llm_request)
     logger.debug(_build_request_log(llm_request))

     messages, tools, response_format, generation_params = (
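
As a side note on the helper above: _part_has_payload treats a Part as carrying payload only if it has non-empty text, non-empty inline bytes, or file data with a URI or bytes; anything else (including a completely empty Part) falls through to False. In generate_content_async, the new helper runs right after self._maybe_append_user_content, so it only has to guarantee that whatever user turn ends up last actually has content. A small sketch of that classification, again importing the private helper only for illustration:

    # Illustration of how _part_has_payload classifies parts (the helper is
    # module-private in google.adk.models.lite_llm; importing it directly is
    # for demonstration only).
    from google.adk.models.lite_llm import _part_has_payload
    from google.genai import types

    assert _part_has_payload(types.Part.from_text(text="hello"))
    assert _part_has_payload(
        types.Part(inline_data=types.Blob(mime_type="image/png", data=b"\x89PNG"))
    )
    assert not _part_has_payload(types.Part())                   # nothing set
    assert not _part_has_payload(types.Part.from_text(text=""))  # empty text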

tests/unittests/models/test_litellm.py

Lines changed: 35 additions & 0 deletions
@@ -548,6 +548,41 @@ async def test_generate_content_async(mock_acompletion, lite_llm_instance):
   )


+@pytest.mark.asyncio
+async def test_generate_content_async_adds_fallback_user_message(
+    mock_acompletion, lite_llm_instance
+):
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[],
+          )
+      ]
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(llm_request):
+    pass
+
+  mock_acompletion.assert_called_once()
+
+  _, kwargs = mock_acompletion.call_args
+  user_messages = [
+      message for message in kwargs["messages"] if message["role"] == "user"
+  ]
+  assert any(
+      message.get("content")
+      == "Handle the requests as specified in the System Instruction."
+      for message in user_messages
+  )
+  assert (
+      sum(1 for content in llm_request.contents if content.role == "user") == 1
+  )
+  assert llm_request.contents[-1].parts[0].text == (
+      "Handle the requests as specified in the System Instruction."
+  )
+
+
 litellm_append_user_content_test_cases = [
     pytest.param(
         LlmRequest(
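
To exercise just the new test, one option (assuming pytest is installed and the command is run from the repository root) is a programmatic pytest invocation filtered by the test name:

    # Run only the new fallback test; assumes pytest is installed and this is
    # executed from the repository root.
    import sys

    import pytest

    sys.exit(
        pytest.main([
            "tests/unittests/models/test_litellm.py",
            "-k", "adds_fallback_user_message",
        ])
    )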
