Skip to content

Commit 8541a99

Browse files
committed
Ease toolbox config in AiBundle
1 parent b43f1bd commit 8541a99

File tree

5 files changed

+48
-66
lines changed

5 files changed

+48
-66
lines changed

src/ai-bundle/config/services.php

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@
1616
use Symfony\AI\Agent\StructuredOutput\ResponseFormatFactoryInterface;
1717
use Symfony\AI\Agent\Toolbox\AgentProcessor as ToolProcessor;
1818
use Symfony\AI\Agent\Toolbox\Toolbox;
19-
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
2019
use Symfony\AI\Agent\Toolbox\ToolCallArgumentResolver;
2120
use Symfony\AI\Agent\Toolbox\ToolFactory\AbstractToolFactory;
2221
use Symfony\AI\Agent\Toolbox\ToolFactory\ReflectionToolFactory;
2322
use Symfony\AI\Agent\Toolbox\ToolResultConverter;
2423
use Symfony\AI\AiBundle\Command\AgentCallCommand;
2524
use Symfony\AI\AiBundle\Command\PlatformInvokeCommand;
2625
use Symfony\AI\AiBundle\Profiler\DataCollector;
27-
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
2826
use Symfony\AI\AiBundle\Security\EventListener\IsGrantedToolAttributeListener;
2927
use Symfony\AI\Chat\Command\DropStoreCommand as DropMessageStoreCommand;
3028
use Symfony\AI\Chat\Command\SetupStoreCommand as SetupMessageStoreCommand;
@@ -130,16 +128,12 @@
130128
->set('ai.toolbox.abstract', Toolbox::class)
131129
->abstract()
132130
->args([
133-
abstract_arg('Collection of tools'),
131+
tagged_iterator('ai.tool'),
134132
service('ai.tool_factory'),
135133
service('ai.tool_call_argument_resolver'),
136134
service('logger')->ignoreOnInvalid(),
137135
service('event_dispatcher')->nullOnInvalid(),
138136
])
139-
->set('ai.toolbox', Toolbox::class)
140-
->parent('ai.toolbox.abstract')
141-
->arg('index_0', tagged_iterator('ai.tool'))
142-
->alias(ToolboxInterface::class, 'ai.toolbox')
143137
->set('ai.tool_factory.abstract', AbstractToolFactory::class)
144138
->abstract()
145139
->args([
@@ -164,9 +158,6 @@
164158
service('event_dispatcher')->nullOnInvalid(),
165159
false,
166160
])
167-
->set('ai.tool.agent_processor', ToolProcessor::class)
168-
->parent('ai.tool.agent_processor.abstract')
169-
->arg('index_0', service('ai.toolbox'))
170161
->set('ai.security.is_granted_attribute_listener', IsGrantedToolAttributeListener::class)
171162
->args([
172163
service('security.authorization_checker'),
@@ -178,16 +169,9 @@
178169
->set('ai.data_collector', DataCollector::class)
179170
->args([
180171
tagged_iterator('ai.traceable_platform'),
181-
service('ai.toolbox'),
182172
tagged_iterator('ai.traceable_toolbox'),
183173
])
184174
->tag('data_collector')
185-
->set('ai.traceable_toolbox', TraceableToolbox::class)
186-
->decorate('ai.toolbox', priority: -1)
187-
->args([
188-
service('.inner'),
189-
])
190-
->tag('ai.traceable_toolbox')
191175

192176
// token usage processors
193177
->set('ai.platform.token_usage_processor.anthropic', AnthropicTokenOutputProcessor::class)

src/ai-bundle/src/AiBundle.php

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -589,9 +589,37 @@ private function processAgentConfig(string $name, array $config, ContainerBuilde
589589
->setArgument(0, new Reference($config['platform']))
590590
->setArgument(1, $config['model']);
591591

592-
// TOOL & PROCESSOR
592+
// TOOLBOX
593593
if ($config['tools']['enabled']) {
594-
// Create specific toolbox and process if tools are explicitly defined
594+
// Setup toolbox for agent
595+
$toolboxDefinition = (new ChildDefinition('ai.toolbox.abstract'))
596+
->replaceArgument(1, new Reference('ai.toolbox.'.$name.'.chain_factory'))
597+
->addTag('ai.toolbox', ['name' => $name]);
598+
$container->setDefinition('ai.toolbox.'.$name, $toolboxDefinition);
599+
600+
if ($config['fault_tolerant_toolbox']) {
601+
$container->setDefinition('ai.fault_tolerant_toolbox.'.$name, new Definition(FaultTolerantToolbox::class))
602+
->setArguments([new Reference('.inner')])
603+
->setDecoratedService('ai.toolbox.'.$name);
604+
}
605+
606+
if ($container->getParameter('kernel.debug')) {
607+
$traceableToolboxDefinition = (new Definition('ai.traceable_toolbox.'.$name))
608+
->setClass(TraceableToolbox::class)
609+
->setArguments([new Reference('.inner')])
610+
->setDecoratedService('ai.toolbox.'.$name)
611+
->addTag('ai.traceable_toolbox');
612+
$container->setDefinition('ai.traceable_toolbox.'.$name, $traceableToolboxDefinition);
613+
}
614+
615+
$toolProcessorDefinition = (new ChildDefinition('ai.tool.agent_processor.abstract'))
616+
->replaceArgument(0, new Reference('ai.toolbox.'.$name));
617+
618+
$container->setDefinition('ai.tool.agent_processor.'.$name, $toolProcessorDefinition)
619+
->addTag('ai.agent.input_processor', ['agent' => $agentId, 'priority' => -10])
620+
->addTag('ai.agent.output_processor', ['agent' => $agentId, 'priority' => -10]);
621+
622+
// Define specific list of tools if are explicitly defined
595623
if ([] !== $config['tools']['services']) {
596624
$memoryFactoryDefinition = new ChildDefinition('ai.tool_factory.abstract');
597625
$memoryFactoryDefinition->setClass(MemoryToolFactory::class);
@@ -620,42 +648,7 @@ private function processAgentConfig(string $name, array $config, ContainerBuilde
620648
$tools[] = $reference;
621649
}
622650

623-
$toolboxDefinition = (new ChildDefinition('ai.toolbox.abstract'))
624-
->replaceArgument(0, $tools)
625-
->replaceArgument(1, new Reference('ai.toolbox.'.$name.'.chain_factory'));
626-
$container->setDefinition('ai.toolbox.'.$name, $toolboxDefinition);
627-
628-
if ($config['fault_tolerant_toolbox']) {
629-
$container->setDefinition('ai.fault_tolerant_toolbox.'.$name, new Definition(FaultTolerantToolbox::class))
630-
->setArguments([new Reference('.inner')])
631-
->setDecoratedService('ai.toolbox.'.$name);
632-
}
633-
634-
if ($container->getParameter('kernel.debug')) {
635-
$traceableToolboxDefinition = (new Definition('ai.traceable_toolbox.'.$name))
636-
->setClass(TraceableToolbox::class)
637-
->setArguments([new Reference('.inner')])
638-
->setDecoratedService('ai.toolbox.'.$name)
639-
->addTag('ai.traceable_toolbox');
640-
$container->setDefinition('ai.traceable_toolbox.'.$name, $traceableToolboxDefinition);
641-
}
642-
643-
$toolProcessorDefinition = (new ChildDefinition('ai.tool.agent_processor.abstract'))
644-
->replaceArgument(0, new Reference('ai.toolbox.'.$name));
645-
646-
$container->setDefinition('ai.tool.agent_processor.'.$name, $toolProcessorDefinition)
647-
->addTag('ai.agent.input_processor', ['agent' => $agentId, 'priority' => -10])
648-
->addTag('ai.agent.output_processor', ['agent' => $agentId, 'priority' => -10]);
649-
} else {
650-
if ($config['fault_tolerant_toolbox'] && !$container->hasDefinition('ai.fault_tolerant_toolbox')) {
651-
$container->setDefinition('ai.fault_tolerant_toolbox', new Definition(FaultTolerantToolbox::class))
652-
->setArguments([new Reference('.inner')])
653-
->setDecoratedService('ai.toolbox');
654-
}
655-
656-
$container->getDefinition('ai.tool.agent_processor')
657-
->addTag('ai.agent.input_processor', ['agent' => $agentId, 'priority' => -10])
658-
->addTag('ai.agent.output_processor', ['agent' => $agentId, 'priority' => -10]);
651+
$toolboxDefinition->replaceArgument(0, $tools);
659652
}
660653
}
661654

src/ai-bundle/src/Profiler/DataCollector.php

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
namespace Symfony\AI\AiBundle\Profiler;
1313

14-
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
1514
use Symfony\AI\Agent\Toolbox\ToolResult;
1615
use Symfony\AI\Platform\Tool\Tool;
1716
use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector;
@@ -42,7 +41,6 @@ final class DataCollector extends AbstractDataCollector implements LateDataColle
4241
*/
4342
public function __construct(
4443
iterable $platforms,
45-
private readonly ToolboxInterface $defaultToolBox,
4644
iterable $toolboxes,
4745
) {
4846
$this->platforms = $platforms instanceof \Traversable ? iterator_to_array($platforms) : $platforms;
@@ -57,7 +55,7 @@ public function collect(Request $request, Response $response, ?\Throwable $excep
5755
public function lateCollect(): void
5856
{
5957
$this->data = [
60-
'tools' => $this->defaultToolBox->getTools(),
58+
'tools' => $this->getAllTools(),
6159
'platform_calls' => array_merge(...array_map($this->awaitCallResults(...), $this->platforms)),
6260
'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)),
6361
];
@@ -92,6 +90,14 @@ public function getToolCalls(): array
9290
return $this->data['tool_calls'] ?? [];
9391
}
9492

93+
/**
94+
* @return Tool[]
95+
*/
96+
private function getAllTools(): array
97+
{
98+
return array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->getTools(), $this->toolboxes));
99+
}
100+
95101
/**
96102
* @return array{
97103
* model: string,

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ public function testFaultTolerantDefaultToolbox(bool $enabled)
271271
],
272272
]);
273273

274-
$this->assertSame($enabled, $container->hasDefinition('ai.fault_tolerant_toolbox'));
274+
$this->assertSame($enabled, $container->hasDefinition('ai.fault_tolerant_toolbox.my_agent'));
275275
}
276276

277277
public function testAgentsCanBeRegisteredAsTools()
@@ -619,23 +619,23 @@ public function testMultipleAgentsWithProcessors()
619619
}
620620

621621
#[TestDox('Processors work correctly when using the default toolbox')]
622-
public function testDefaultToolboxProcessorTags()
622+
public function testToolboxWithoutExplicitToolsDefined()
623623
{
624624
$container = $this->buildContainer([
625625
'ai' => [
626626
'agent' => [
627-
'agent_with_default_toolbox' => [
627+
'agent_with_tools' => [
628628
'model' => 'gpt-4',
629629
'tools' => true,
630630
],
631631
],
632632
],
633633
]);
634634

635-
$agentId = 'ai.agent.agent_with_default_toolbox';
635+
$agentId = 'ai.agent.agent_with_tools';
636636

637637
// When using default toolbox, the ai.tool.agent_processor service gets the tags
638-
$defaultToolProcessor = $container->getDefinition('ai.tool.agent_processor');
638+
$defaultToolProcessor = $container->getDefinition('ai.tool.agent_processor.agent_with_tools');
639639
$inputTags = $defaultToolProcessor->getTag('ai.agent.input_processor');
640640
$outputTags = $defaultToolProcessor->getTag('ai.agent.output_processor');
641641

src/ai-bundle/tests/Profiler/DataCollectorTest.php

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
namespace Symfony\AI\AiBundle\Tests\Profiler;
1313

1414
use PHPUnit\Framework\TestCase;
15-
use Symfony\AI\Agent\Toolbox\ToolboxInterface;
1615
use Symfony\AI\AiBundle\Profiler\DataCollector;
1716
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
1817
use Symfony\AI\Platform\Message\Content\Text;
@@ -39,7 +38,7 @@ public function testCollectsDataForNonStreamingResponse()
3938
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => false]);
4039
$this->assertSame('Assistant response', $result->asText());
4140

42-
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), []);
41+
$dataCollector = new DataCollector([$traceablePlatform], []);
4342
$dataCollector->lateCollect();
4443

4544
$this->assertCount(1, $dataCollector->getPlatformCalls());
@@ -63,7 +62,7 @@ public function testCollectsDataForStreamingResponse()
6362
$result = $traceablePlatform->invoke('gpt-4o', $messageBag, ['stream' => true]);
6463
$this->assertSame('Assistant response', implode('', iterator_to_array($result->asStream())));
6564

66-
$dataCollector = new DataCollector([$traceablePlatform], $this->createStub(ToolboxInterface::class), []);
65+
$dataCollector = new DataCollector([$traceablePlatform], []);
6766
$dataCollector->lateCollect();
6867

6968
$this->assertCount(1, $dataCollector->getPlatformCalls());

0 commit comments

Comments
 (0)