Skip to content

Commit ffcff42

Browse files
committed
test: integration tests for palmyra provider
1 parent 0303a8d commit ffcff42

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

tests-integ/test_model_palmyra.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import os
2+
3+
import pytest
4+
5+
import strands
6+
from strands import Agent
7+
from strands.models.palmyra import PalmyraModel
8+
9+
if not os.getenv("WRITER_API_KEY"):
10+
pytest.skip("WRITER_API_KEY environment variable missing", allow_module_level=True)
11+
12+
13+
@pytest.fixture
14+
def model():
15+
return PalmyraModel(
16+
model="palmyra-x5",
17+
client_args={"api_key": os.getenv("WRITER_API_KEY", "")},
18+
stream_options={"include_usage": True},
19+
)
20+
21+
22+
@pytest.fixture
23+
def system_prompt():
24+
return "You are a smart assistant, that uses @ instead of all punctuation marks. It is an obligation!"
25+
26+
27+
@pytest.fixture
28+
def tools():
29+
@strands.tool
30+
def tool_time(location: str) -> str:
31+
"""Returning time in the specific location.
32+
33+
Args:
34+
location: Location to return time at.
35+
"""
36+
37+
return "12:00"
38+
39+
@strands.tool
40+
def tool_weather(location: str, time: str) -> str:
41+
"""Returning weather in the specific location and time.
42+
43+
Args:
44+
location: Location to return weather at.
45+
time: Moment of time to return weather in specified location.
46+
"""
47+
48+
return "sunny"
49+
50+
return [tool_time, tool_weather]
51+
52+
53+
@pytest.fixture
54+
def agent(model, tools, system_prompt):
55+
return Agent(model=model, tools=tools, system_prompt=system_prompt, load_tools_from_directory=False)
56+
57+
58+
def test_agent(agent):
59+
response = agent("How are you?")
60+
61+
assert len(response.message) > 0
62+
assert "@" in response.message.get("content", [])[0].get("text", "")
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_async_streaming_agent(agent):
67+
response = agent.stream_async("How are you?")
68+
69+
full_message = ""
70+
async for event in response:
71+
if delta_text := event.get("event", {}).get("contentBlockDelta", {}).get("delta", {}).get("text", ""):
72+
full_message += delta_text
73+
74+
assert len(full_message) > 0
75+
assert "@" in full_message
76+
77+
78+
def test_model_events(model):
79+
messages = [{"role": "user", "content": [{"text": "How are you?"}]}]
80+
81+
response_events = {key for x in model.converse(messages) for key in x.keys()}
82+
83+
assert all(
84+
[
85+
event_type in response_events
86+
for event_type in [
87+
"messageStart",
88+
"contentBlockStart",
89+
"contentBlockDelta",
90+
"contentBlockStop",
91+
"messageStop",
92+
"metadata",
93+
]
94+
]
95+
)
96+
97+
98+
def test_agent_with_tool_calls(agent, model):
99+
model.update_config(model="palmyra-x4")
100+
agent.system_prompt = ""
101+
102+
response = agent("What is the time and weather in Warsaw?")
103+
response_message_text = response.message.get("content", [])[0].get("text", "")
104+
105+
assert len(response.message) > 0
106+
assert "12" in response_message_text
107+
assert "sunny" in response_message_text
108+
assert all(tool in response.metrics.tool_metrics for tool in ["tool_time", "tool_weather"])
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_async_streaming_agent_with_tool_calls(agent, model):
113+
model.update_config(model="palmyra-x4")
114+
agent.system_prompt = ""
115+
116+
response = agent.stream_async("What is the time and weather in Warsaw?")
117+
118+
full_message = ""
119+
async for event in response:
120+
if delta_text := event.get("event", {}).get("contentBlockDelta", {}).get("delta", {}).get("text", ""):
121+
full_message += delta_text
122+
123+
assert len(full_message) > 0
124+
assert "12" in full_message
125+
assert "sunny" in full_message

0 commit comments

Comments
 (0)