Skip to content

Commit 6a9eb70

Browse files
Dominic-Wagnerchr-hertel
authored andcommitted
[Platform] add ollama toolcall support for streaming
1 parent beadd5c commit 6a9eb70

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <fabien@symfony.com>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Symfony\AI\Agent\Agent;
13+
use Symfony\AI\Agent\Toolbox\AgentProcessor;
14+
use Symfony\AI\Agent\Toolbox\Tool\Clock;
15+
use Symfony\AI\Agent\Toolbox\Toolbox;
16+
use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory;
17+
use Symfony\AI\Platform\Message\Message;
18+
use Symfony\AI\Platform\Message\MessageBag;
19+
20+
require_once dirname(__DIR__).'/bootstrap.php';
21+
22+
$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client());
23+
24+
$toolbox = new Toolbox([new Clock()], logger: logger());
25+
$processor = new AgentProcessor($toolbox);
26+
$agent = new Agent($platform, env('OLLAMA_LLM'), [$processor], [$processor]);
27+
28+
$messages = new MessageBag(Message::ofUser('What time is it?'));
29+
30+
$result = $agent->call($messages, ['stream' => true]);
31+
32+
foreach ($result->getContent() as $chunk) {
33+
echo $chunk->getContent();
34+
}
35+
36+
echo \PHP_EOL;

src/platform/src/Bridge/Ollama/OllamaResultConverter.php

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ public function doConvertEmbeddings(array $data): ResultInterface
9595

9696
private function convertStream(ResponseInterface $result): \Generator
9797
{
98+
$toolCalls = [];
9899
foreach ((new EventSourceHttpClient())->stream($result) as $chunk) {
99100
if ($chunk instanceof FirstChunk || $chunk instanceof LastChunk) {
100101
continue;
@@ -106,6 +107,14 @@ private function convertStream(ResponseInterface $result): \Generator
106107
throw new RuntimeException('Failed to decode JSON: '.$e->getMessage());
107108
}
108109

110+
if ($this->streamIsToolCall($data)) {
111+
$toolCalls = $this->convertStreamToToolCalls($toolCalls, $data);
112+
}
113+
114+
if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) {
115+
yield new ToolCallResult(...$toolCalls);
116+
}
117+
109118
yield new OllamaMessageChunk(
110119
$data['model'],
111120
new \DateTimeImmutable($data['created_at']),
@@ -114,4 +123,39 @@ private function convertStream(ResponseInterface $result): \Generator
114123
);
115124
}
116125
}
126+
127+
/**
128+
* @param array<string, mixed> $toolCalls
129+
* @param array<string, mixed> $data
130+
*
131+
* @return array<ToolCall>
132+
*/
133+
private function convertStreamToToolCalls(array $toolCalls, array $data): array
134+
{
135+
if (!isset($data['message']['tool_calls'])) {
136+
return $toolCalls;
137+
}
138+
139+
foreach ($data['message']['tool_calls'] ?? [] as $id => $toolCall) {
140+
$toolCalls[] = new ToolCall($id, $toolCall['function']['name'], $toolCall['function']['arguments']);
141+
}
142+
143+
return $toolCalls;
144+
}
145+
146+
/**
147+
* @param array<string, mixed> $data
148+
*/
149+
private function streamIsToolCall(array $data): bool
150+
{
151+
return isset($data['message']['tool_calls']);
152+
}
153+
154+
/**
155+
* @param array<string, mixed> $data^
156+
*/
157+
private function isToolCallsStreamFinished(array $data): bool
158+
{
159+
return isset($data['done']) && true === $data['done'];
160+
}
117161
}

0 commit comments

Comments
 (0)