Skip to content

Commit ddd2c3d

Browse files
committed
feat: adds support for specifying provider in PromptBuilder
1 parent e5e1f18 commit ddd2c3d

File tree

2 files changed

+193
-16
lines changed

2 files changed

+193
-16
lines changed

src/Builders/PromptBuilder.php

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class PromptBuilder
5757
*/
5858
protected ?ModelInterface $model = null;
5959

60+
/**
61+
* @var string|null The provider ID or class name.
62+
*/
63+
protected ?string $providerIdOrClassName = null;
64+
6065
/**
6166
* @var ModelConfig The model configuration.
6267
*/
@@ -198,6 +203,20 @@ public function usingModel(ModelInterface $model): self
198203
return $this;
199204
}
200205

206+
/**
207+
* Sets the provider to use for generation.
208+
*
209+
* @since n.e.x.t
210+
*
211+
* @param string $providerIdOrClassName The provider ID or class name.
212+
* @return self
213+
*/
214+
public function usingProvider(string $providerIdOrClassName): self
215+
{
216+
$this->providerIdOrClassName = $providerIdOrClassName;
217+
return $this;
218+
}
219+
201220
/**
202221
* Sets the system instruction.
203222
*
@@ -930,28 +949,50 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
930949
}
931950

932951
// Find a suitable model based on requirements
933-
$modelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
952+
if ($this->providerIdOrClassName === null) {
953+
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
934954

935-
if (empty($modelsMetadata)) {
936-
throw new InvalidArgumentException(
937-
'No models found that support the required capabilities and options for this prompt. ' .
938-
'Required capabilities: ' . implode(', ', array_map(function ($cap) {
939-
return $cap->value;
940-
}, $requirements->getRequiredCapabilities())) .
941-
'. Required options: ' . implode(', ', array_map(function ($opt) {
942-
return $opt->getName()->value . '=' . json_encode($opt->getValue());
943-
}, $requirements->getRequiredOptions()))
955+
if (empty($providerModelsMetadata)) {
956+
throw new InvalidArgumentException(
957+
'No models found that support the required capabilities and options for this prompt. ' .
958+
'Required capabilities: ' . implode(', ', array_map(function ($cap) {
959+
return $cap->value;
960+
}, $requirements->getRequiredCapabilities())) .
961+
'. Required options: ' . implode(', ', array_map(function ($opt) {
962+
return $opt->getName()->value . '=' . json_encode($opt->getValue());
963+
}, $requirements->getRequiredOptions()))
964+
);
965+
}
966+
967+
$firstProviderModels = $providerModelsMetadata[0];
968+
$provider = $firstProviderModels->getProvider()->getId();
969+
$modelMetadata = $firstProviderModels->getModels()[0];
970+
} else {
971+
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
972+
$this->providerIdOrClassName,
973+
$requirements
944974
);
945-
}
946975

947-
// Get the first available model from the first provider
948-
$firstProviderModels = $modelsMetadata[0];
949-
$firstModelMetadata = $firstProviderModels->getModels()[0];
976+
if (empty($modelsMetadata)) {
977+
throw new InvalidArgumentException(
978+
'No models found that support the required capabilities and options for this prompt. ' .
979+
'Required capabilities: ' . implode(', ', array_map(function ($cap) {
980+
return $cap->value;
981+
}, $requirements->getRequiredCapabilities())) .
982+
'. Required options: ' . implode(', ', array_map(function ($opt) {
983+
return $opt->getName()->value . '=' . json_encode($opt->getValue());
984+
}, $requirements->getRequiredOptions()))
985+
);
986+
}
987+
988+
$provider = $this->providerIdOrClassName;
989+
$modelMetadata = $modelsMetadata[0];
990+
}
950991

951992
// Get the model instance from the provider
952993
return $this->registry->getProviderModel(
953-
$firstProviderModels->getProvider()->getId(),
954-
$firstModelMetadata->getId(),
994+
$provider,
995+
$modelMetadata->getId(),
955996
$this->modelConfig
956997
);
957998
}

tests/unit/Builders/PromptBuilderTest.php

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
2020
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
2121
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
22+
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;
2223
use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface;
2324
use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface;
2425
use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface;
@@ -613,6 +614,26 @@ public function testUsingModel(): void
613614
$this->assertSame($model, $actualModel);
614615
}
615616

617+
/**
618+
* Tests usingProvider method.
619+
*
620+
* @return void
621+
*/
622+
public function testUsingProvider(): void
623+
{
624+
$builder = new PromptBuilder($this->registry);
625+
$result = $builder->usingProvider('test-provider');
626+
627+
$this->assertSame($builder, $result);
628+
629+
$reflection = new \ReflectionClass($builder);
630+
$providerProperty = $reflection->getProperty('providerIdOrClassName');
631+
$providerProperty->setAccessible(true);
632+
633+
$actualProvider = $providerProperty->getValue($builder);
634+
$this->assertEquals('test-provider', $actualProvider);
635+
}
636+
616637
/**
617638
* Tests usingSystemInstruction method.
618639
*
@@ -2462,4 +2483,119 @@ public function testIsSupportedForSpeechGeneration(): void
24622483

24632484
$this->assertTrue($builder->isSupportedForSpeechGeneration());
24642485
}
2486+
2487+
/**
2488+
* Tests generateResult with provider specified.
2489+
*
2490+
* @return void
2491+
*/
2492+
public function testGenerateResultWithProvider(): void
2493+
{
2494+
$result = $this->createMock(GenerativeAiResult::class);
2495+
2496+
$modelMetadata = $this->createMock(ModelMetadata::class);
2497+
$modelMetadata->method('getId')->willReturn('provider-model');
2498+
$modelMetadata->method('meetsRequirements')->willReturn(true);
2499+
2500+
$model = $this->createTextGenerationModel($modelMetadata, $result);
2501+
2502+
// Mock the registry to return the model when provider is specified
2503+
$this->registry->expects($this->once())
2504+
->method('findProviderModelsMetadataForSupport')
2505+
->with('test-provider', $this->isInstanceOf(ModelRequirements::class))
2506+
->willReturn([$modelMetadata]);
2507+
2508+
$this->registry->expects($this->once())
2509+
->method('getProviderModel')
2510+
->with('test-provider', 'provider-model', $this->isInstanceOf(ModelConfig::class))
2511+
->willReturn($model);
2512+
2513+
$builder = new PromptBuilder($this->registry, 'Test prompt');
2514+
$builder->usingProvider('test-provider');
2515+
2516+
$actualResult = $builder->generateResult();
2517+
$this->assertSame($result, $actualResult);
2518+
}
2519+
2520+
/**
2521+
* Tests generateResult with provider but no suitable models.
2522+
*
2523+
* @return void
2524+
*/
2525+
public function testGenerateResultWithProviderNoModelsThrowsException(): void
2526+
{
2527+
// Mock the registry to return empty array when provider is specified
2528+
$this->registry->expects($this->once())
2529+
->method('findProviderModelsMetadataForSupport')
2530+
->with('test-provider', $this->isInstanceOf(ModelRequirements::class))
2531+
->willReturn([]);
2532+
2533+
$builder = new PromptBuilder($this->registry, 'Test prompt');
2534+
$builder->usingProvider('test-provider');
2535+
2536+
$this->expectException(InvalidArgumentException::class);
2537+
$this->expectExceptionMessage('No models found that support the required capabilities');
2538+
2539+
$builder->generateResult();
2540+
}
2541+
2542+
/**
2543+
* Tests that provider takes precedence when both provider and model are set.
2544+
*
2545+
* @return void
2546+
*/
2547+
public function testModelTakesPrecedenceOverProvider(): void
2548+
{
2549+
$result = $this->createMock(GenerativeAiResult::class);
2550+
2551+
$metadata = $this->createMock(ModelMetadata::class);
2552+
$metadata->method('getId')->willReturn('explicit-model');
2553+
$metadata->method('meetsRequirements')->willReturn(true);
2554+
2555+
$model = $this->createTextGenerationModel($metadata, $result);
2556+
2557+
// Registry should not be called when model is explicitly set
2558+
$this->registry->expects($this->never())
2559+
->method('findProviderModelsMetadataForSupport');
2560+
$this->registry->expects($this->never())
2561+
->method('getProviderModel');
2562+
2563+
$builder = new PromptBuilder($this->registry, 'Test prompt');
2564+
$builder->usingProvider('test-provider');
2565+
$builder->usingModel($model); // Model overrides provider
2566+
2567+
$actualResult = $builder->generateResult();
2568+
$this->assertSame($result, $actualResult);
2569+
}
2570+
2571+
/**
2572+
* Tests fluent interface with provider.
2573+
*
2574+
* @return void
2575+
*/
2576+
public function testFluentInterfaceWithProvider(): void
2577+
{
2578+
$builder = new PromptBuilder($this->registry, 'Initial text');
2579+
2580+
$result = $builder
2581+
->usingProvider('my-provider')
2582+
->withText(' Additional text')
2583+
->usingMaxTokens(500)
2584+
->usingTemperature(0.7);
2585+
2586+
$this->assertSame($builder, $result);
2587+
2588+
$reflection = new \ReflectionClass($builder);
2589+
2590+
$providerProperty = $reflection->getProperty('providerIdOrClassName');
2591+
$providerProperty->setAccessible(true);
2592+
$this->assertEquals('my-provider', $providerProperty->getValue($builder));
2593+
2594+
$configProperty = $reflection->getProperty('modelConfig');
2595+
$configProperty->setAccessible(true);
2596+
/** @var ModelConfig $config */
2597+
$config = $configProperty->getValue($builder);
2598+
$this->assertEquals(500, $config->getMaxTokens());
2599+
$this->assertEquals(0.7, $config->getTemperature());
2600+
}
24652601
}

0 commit comments

Comments
 (0)