Skip to content

Commit 648d0da

Browse files
committed
test: Add test for content_to_message_param
1 parent f70eeb3 commit 648d0da

File tree

1 file changed

+173
-97
lines changed

1 file changed

+173
-97
lines changed

tests/unittests/models/test_anthropic_llm.py

Lines changed: 173 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.adk import version as adk_version
2121
from google.adk.models import anthropic_llm
2222
from google.adk.models.anthropic_llm import Claude
23+
from google.adk.models.anthropic_llm import content_to_message_param
2324
from google.adk.models.anthropic_llm import function_declaration_to_tool_param
2425
from google.adk.models.llm_request import LlmRequest
2526
from google.adk.models.llm_response import LlmResponse
@@ -32,69 +33,69 @@
3233

3334
@pytest.fixture
3435
def generate_content_response():
35-
return anthropic_types.Message(
36-
id="msg_vrtx_testid",
37-
content=[
38-
anthropic_types.TextBlock(
39-
citations=None, text="Hi! How can I help you today?", type="text"
40-
)
41-
],
42-
model="claude-3-5-sonnet-v2-20241022",
43-
role="assistant",
44-
stop_reason="end_turn",
45-
stop_sequence=None,
46-
type="message",
47-
usage=anthropic_types.Usage(
48-
cache_creation_input_tokens=0,
49-
cache_read_input_tokens=0,
50-
input_tokens=13,
51-
output_tokens=12,
52-
server_tool_use=None,
53-
service_tier=None,
54-
),
55-
)
36+
return anthropic_types.Message(
37+
id="msg_vrtx_testid",
38+
content=[
39+
anthropic_types.TextBlock(
40+
citations=None, text="Hi! How can I help you today?", type="text"
41+
)
42+
],
43+
model="claude-3-5-sonnet-v2-20241022",
44+
role="assistant",
45+
stop_reason="end_turn",
46+
stop_sequence=None,
47+
type="message",
48+
usage=anthropic_types.Usage(
49+
cache_creation_input_tokens=0,
50+
cache_read_input_tokens=0,
51+
input_tokens=13,
52+
output_tokens=12,
53+
server_tool_use=None,
54+
service_tier=None,
55+
),
56+
)
5657

5758

5859
@pytest.fixture
5960
def generate_llm_response():
60-
return LlmResponse.create(
61-
types.GenerateContentResponse(
62-
candidates=[
63-
types.Candidate(
64-
content=Content(
65-
role="model",
66-
parts=[Part.from_text(text="Hello, how can I help you?")],
67-
),
68-
finish_reason=types.FinishReason.STOP,
69-
)
70-
]
71-
)
72-
)
61+
return LlmResponse.create(
62+
types.GenerateContentResponse(
63+
candidates=[
64+
types.Candidate(
65+
content=Content(
66+
role="model",
67+
parts=[Part.from_text(text="Hello, how can I help you?")],
68+
),
69+
finish_reason=types.FinishReason.STOP,
70+
)
71+
]
72+
)
73+
)
7374

7475

7576
@pytest.fixture
7677
def claude_llm():
77-
return Claude(model="claude-3-5-sonnet-v2@20241022")
78+
return Claude(model="claude-3-5-sonnet-v2@20241022")
7879

7980

8081
@pytest.fixture
8182
def llm_request():
82-
return LlmRequest(
83-
model="claude-3-5-sonnet-v2@20241022",
84-
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
85-
config=types.GenerateContentConfig(
86-
temperature=0.1,
87-
response_modalities=[types.Modality.TEXT],
88-
system_instruction="You are a helpful assistant",
89-
),
90-
)
83+
return LlmRequest(
84+
model="claude-3-5-sonnet-v2@20241022",
85+
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
86+
config=types.GenerateContentConfig(
87+
temperature=0.1,
88+
response_modalities=[types.Modality.TEXT],
89+
system_instruction="You are a helpful assistant",
90+
),
91+
)
9192

9293

9394
def test_supported_models():
94-
models = Claude.supported_models()
95-
assert len(models) == 2
96-
assert models[0] == r"claude-3-.*"
97-
assert models[1] == r"claude-.*-4.*"
95+
models = Claude.supported_models()
96+
assert len(models) == 2
97+
assert models[0] == r"claude-3-.*"
98+
assert models[1] == r"claude-.*-4.*"
9899

99100

100101
function_declaration_test_cases = [
@@ -133,9 +134,7 @@ def test_supported_models():
133134
"properties": {
134135
"location": {
135136
"type": "string",
136-
"description": (
137-
"City and state, e.g., San Francisco, CA"
138-
),
137+
"description": ("City and state, e.g., San Francisco, CA"),
139138
}
140139
},
141140
},
@@ -284,65 +283,142 @@ def test_supported_models():
284283
async def test_function_declaration_to_tool_param(
285284
_, function_declaration, expected_tool_param
286285
):
287-
"""Test function_declaration_to_tool_param."""
288-
assert (
289-
function_declaration_to_tool_param(function_declaration)
290-
== expected_tool_param
291-
)
286+
"""Test function_declaration_to_tool_param."""
287+
assert (
288+
function_declaration_to_tool_param(function_declaration) == expected_tool_param
289+
)
292290

293291

294292
@pytest.mark.asyncio
295293
async def test_generate_content_async(
296294
claude_llm, llm_request, generate_content_response, generate_llm_response
297295
):
298-
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
299-
with mock.patch.object(
300-
anthropic_llm,
301-
"message_to_generate_content_response",
302-
return_value=generate_llm_response,
303-
):
304-
# Create a mock coroutine that returns the generate_content_response.
305-
async def mock_coro():
306-
return generate_content_response
296+
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
297+
with mock.patch.object(
298+
anthropic_llm,
299+
"message_to_generate_content_response",
300+
return_value=generate_llm_response,
301+
):
302+
# Create a mock coroutine that returns the generate_content_response.
303+
async def mock_coro():
304+
return generate_content_response
307305

308-
# Assign the coroutine to the mocked method
309-
mock_client.messages.create.return_value = mock_coro()
306+
# Assign the coroutine to the mocked method
307+
mock_client.messages.create.return_value = mock_coro()
310308

311-
responses = [
312-
resp
313-
async for resp in claude_llm.generate_content_async(
314-
llm_request, stream=False
315-
)
316-
]
317-
assert len(responses) == 1
318-
assert isinstance(responses[0], LlmResponse)
319-
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
309+
responses = [
310+
resp
311+
async for resp in claude_llm.generate_content_async(
312+
llm_request, stream=False
313+
)
314+
]
315+
assert len(responses) == 1
316+
assert isinstance(responses[0], LlmResponse)
317+
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
320318

321319

322320
@pytest.mark.asyncio
323321
async def test_generate_content_async_with_max_tokens(
324322
llm_request, generate_content_response, generate_llm_response
325323
):
326-
claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
327-
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
328-
with mock.patch.object(
329-
anthropic_llm,
330-
"message_to_generate_content_response",
331-
return_value=generate_llm_response,
332-
):
333-
# Create a mock coroutine that returns the generate_content_response.
334-
async def mock_coro():
335-
return generate_content_response
324+
claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
325+
with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
326+
with mock.patch.object(
327+
anthropic_llm,
328+
"message_to_generate_content_response",
329+
return_value=generate_llm_response,
330+
):
331+
# Create a mock coroutine that returns the generate_content_response.
332+
async def mock_coro():
333+
return generate_content_response
334+
335+
# Assign the coroutine to the mocked method
336+
mock_client.messages.create.return_value = mock_coro()
337+
338+
_ = [
339+
resp
340+
async for resp in claude_llm.generate_content_async(
341+
llm_request, stream=False
342+
)
343+
]
344+
mock_client.messages.create.assert_called_once()
345+
_, kwargs = mock_client.messages.create.call_args
346+
assert kwargs["max_tokens"] == 4096
347+
348+
349+
content_to_message_param_test_cases = [
350+
(
351+
"user_role_with_text_and_image",
352+
Content(
353+
role="user",
354+
parts=[
355+
Part.from_text(text="What's in this image?"),
356+
Part(
357+
inline_data=types.Blob(
358+
mime_type="image/jpeg", data=b"fake_image_data"
359+
)
360+
),
361+
],
362+
),
363+
"user",
364+
2, # Expected content length
365+
False, # Should not log warning
366+
),
367+
(
368+
"model_role_with_text_and_image",
369+
Content(
370+
role="model",
371+
parts=[
372+
Part.from_text(text="I see a cat."),
373+
Part(
374+
inline_data=types.Blob(
375+
mime_type="image/png", data=b"fake_image_data"
376+
)
377+
),
378+
],
379+
),
380+
"assistant",
381+
1, # Image filtered out, only text remains
382+
True, # Should log warning
383+
),
384+
(
385+
"assistant_role_with_text_and_image",
386+
Content(
387+
role="assistant",
388+
parts=[
389+
Part.from_text(text="Here's what I found."),
390+
Part(
391+
inline_data=types.Blob(
392+
mime_type="image/webp", data=b"fake_image_data"
393+
)
394+
),
395+
],
396+
),
397+
"assistant",
398+
1, # Image filtered out, only text remains
399+
True, # Should log warning
400+
),
401+
]
402+
403+
404+
@pytest.mark.parametrize(
405+
"_, content, expected_role, expected_content_length, should_log_warning",
406+
content_to_message_param_test_cases,
407+
ids=[case[0] for case in content_to_message_param_test_cases],
408+
)
409+
def test_content_to_message_param_with_images(
410+
_, content, expected_role, expected_content_length, should_log_warning
411+
):
412+
"""Test content_to_message_param handles images correctly based on role."""
413+
with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger:
414+
result = content_to_message_param(content)
336415

337-
# Assign the coroutine to the mocked method
338-
mock_client.messages.create.return_value = mock_coro()
416+
assert result["role"] == expected_role
417+
assert len(result["content"]) == expected_content_length
339418

340-
_ = [
341-
resp
342-
async for resp in claude_llm.generate_content_async(
343-
llm_request, stream=False
344-
)
345-
]
346-
mock_client.messages.create.assert_called_once()
347-
_, kwargs = mock_client.messages.create.call_args
348-
assert kwargs["max_tokens"] == 4096
419+
if should_log_warning:
420+
mock_logger.warning.assert_called_once_with(
421+
"Image data is not supported in Claude for model turns."
422+
)
423+
else:
424+
mock_logger.warning.assert_not_called()

0 commit comments

Comments
 (0)