Commit 93f2eb6

models - ollama - async (#373)
1 parent a06b9b1 commit 93f2eb6

3 files changed: +76 −30 lines

The commit swaps the Ollama model provider's synchronous ollama.Client for ollama.AsyncClient: chat calls are now awaited, streaming responses are consumed with async for, and the unit and integration tests are updated to match.

src/strands/models/ollama.py

Lines changed: 5 additions & 5 deletions

@@ -7,7 +7,7 @@
 import logging
 from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast
 
-from ollama import Client as OllamaClient
+import ollama
 from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
@@ -74,7 +74,7 @@ def __init__(
 
         ollama_client_args = ollama_client_args if ollama_client_args is not None else {}
 
-        self.client = OllamaClient(host, **ollama_client_args)
+        self.client = ollama.AsyncClient(host, **ollama_client_args)
 
     @override
     def update_config(self, **model_config: Unpack[OllamaConfig]) -> None:  # type: ignore
@@ -296,12 +296,12 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
         """
         tool_requested = False
 
-        response = self.client.chat(**request)
+        response = await self.client.chat(**request)
 
         yield {"chunk_type": "message_start"}
         yield {"chunk_type": "content_start", "data_type": "text"}
 
-        for event in response:
+        async for event in response:
             for tool_call in event.message.tool_calls or []:
                 yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call}
                 yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_call}
@@ -330,7 +330,7 @@ async def structured_output(
         formatted_request = self.format_request(messages=prompt)
         formatted_request["format"] = output_model.model_json_schema()
         formatted_request["stream"] = False
-        response = self.client.chat(**formatted_request)
+        response = await self.client.chat(**formatted_request)
 
         try:
             content = response.message.content.strip()
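
For context, the new call shape mirrors the ollama Python package's async API: AsyncClient.chat is a coroutine, and with stream=True it resolves to an async iterator, which is why stream() both awaits the call and then consumes it with async for. A minimal standalone sketch of the pattern, assuming a local server on the default port and an already-pulled model:

import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient("http://localhost:11434")

    # With stream=True, awaiting chat() yields an async iterator of events.
    response = await client.chat(
        model="llama3.3:70b",  # assumes this model has been pulled locally
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )

    async for event in response:
        print(event.message.content, end="", flush=True)


asyncio.run(main())

The structured_output path is the same awaited call with stream=False and format set to a Pydantic model's JSON schema, so the result is a single response object rather than an iterator.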

tests-integ/test_model_ollama.py

Lines changed: 65 additions & 19 deletions

@@ -2,6 +2,7 @@
 import requests
 from pydantic import BaseModel
 
+import strands
 from strands import Agent
 from strands.models.ollama import OllamaModel
 
@@ -13,35 +14,80 @@ def is_server_available() -> bool:
         return False
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def model():
     return OllamaModel(host="http://localhost:11434", model_id="llama3.3:70b")
 
 
-@pytest.fixture
-def agent(model):
-    return Agent(model=model)
+@pytest.fixture(scope="module")
+def tools():
+    @strands.tool
+    def tool_time() -> str:
+        return "12:00"
 
+    @strands.tool
+    def tool_weather() -> str:
+        return "sunny"
 
-@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
-def test_agent(agent):
-    result = agent("Say 'hello world' with no other text")
-    assert isinstance(result.message["content"][0]["text"], str)
+    return [tool_time, tool_weather]
 
 
-@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
-def test_structured_output(agent):
-    class Weather(BaseModel):
-        """Extract the time and weather.
+@pytest.fixture(scope="module")
+def agent(model, tools):
+    return Agent(model=model, tools=tools)
 
-        Time format: HH:MM
-        Weather: sunny, cloudy, rainy, etc.
-        """
+
+@pytest.fixture(scope="module")
+def weather():
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
 
         time: str
         weather: str
 
-    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+    return Weather(time="12:00", weather="sunny")
+
+
+@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
+def test_agent_invoke(agent):
+    result = agent("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
+def test_agent_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
+
+
+@pytest.mark.skipif(not is_server_available(), reason="Local Ollama endpoint not available at localhost:11434")
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
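
The hunk above shows only the tail of is_server_available, the guard behind every skipif marker. A plausible sketch of such a guard, consistent with the requests import (an assumption; the commit does not change this function):

import requests


def is_server_available() -> bool:
    # Probe the local Ollama endpoint; treat any connection failure as "not running".
    try:
        return requests.get("http://localhost:11434").ok
    except requests.exceptions.ConnectionError:
        return False

With the server up (e.g. ollama serve, then ollama pull llama3.3:70b), the module-scoped fixtures build the agent once and reuse it across the sync, async, and streaming test variants.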

tests/strands/models/test_ollama.py

Lines changed: 6 additions & 6 deletions

@@ -11,7 +11,7 @@
 
 @pytest.fixture
 def ollama_client():
-    with unittest.mock.patch.object(strands.models.ollama, "OllamaClient") as mock_client_cls:
+    with unittest.mock.patch.object(strands.models.ollama.ollama, "AsyncClient") as mock_client_cls:
         yield mock_client_cls.return_value
 
 
@@ -416,13 +416,13 @@ def test_format_chunk_other(model):
 
 
 @pytest.mark.asyncio
-async def test_stream(ollama_client, model, alist):
+async def test_stream(ollama_client, model, agenerator, alist):
     mock_event = unittest.mock.Mock()
     mock_event.message.tool_calls = None
     mock_event.message.content = "Hello"
     mock_event.done_reason = "stop"
 
-    ollama_client.chat.return_value = [mock_event]
+    ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
 
     request = {"model": "m1", "messages": [{"role": "user", "content": "Hello"}]}
     response = model.stream(request)
@@ -442,14 +442,14 @@ async def test_stream(ollama_client, model, alist):
 
 
 @pytest.mark.asyncio
-async def test_stream_with_tool_calls(ollama_client, model, alist):
+async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist):
     mock_event = unittest.mock.Mock()
     mock_tool_call = unittest.mock.Mock()
     mock_event.message.tool_calls = [mock_tool_call]
     mock_event.message.content = "I'll calculate that for you"
     mock_event.done_reason = "stop"
 
-    ollama_client.chat.return_value = [mock_event]
+    ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([mock_event]))
 
     request = {"model": "m1", "messages": [{"role": "user", "content": "Calculate 2+2"}]}
     response = model.stream(request)
@@ -478,7 +478,7 @@ async def test_structured_output(ollama_client, model, test_output_model_cls, al
     mock_response = unittest.mock.Mock()
     mock_response.message.content = '{"name": "John", "age": 30}'
 
-    ollama_client.chat.return_value = mock_response
+    ollama_client.chat = unittest.mock.AsyncMock(return_value=mock_response)
 
     stream = model.structured_output(test_output_model_cls, messages)
     events = await alist(stream)
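
These tests lean on two shared fixtures, agenerator and alist, that are not defined in this file and presumably live in the suite's conftest. A sketch of what such helpers typically look like (an assumption, not part of this commit):

import pytest


@pytest.fixture
def agenerator():
    # Turn a plain iterable into an async generator, mimicking a streaming reply.
    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    # Drain an async iterable into an ordinary list for assertions.
    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect

An AsyncMock is needed because the model now awaits chat(); returning agenerator([mock_event]) reproduces the awaited-then-iterated shape of a real streaming response.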
