Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use InvalidArgumentException;
use RuntimeException;
use WordPress\AiClient\Files\DTO\File;
use WordPress\AiClient\Files\Enums\FileTypeEnum;
use WordPress\AiClient\Messages\DTO\Message;
use WordPress\AiClient\Messages\DTO\MessagePart;
use WordPress\AiClient\Messages\DTO\UserMessage;
Expand All @@ -24,7 +25,9 @@
use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface;
use WordPress\AiClient\Providers\ProviderRegistry;
use WordPress\AiClient\Results\DTO\GenerativeAiResult;
use WordPress\AiClient\Tools\DTO\FunctionDeclaration;
use WordPress\AiClient\Tools\DTO\FunctionResponse;
use WordPress\AiClient\Tools\DTO\WebSearch;

/**
* Fluent builder for constructing AI prompts.
Expand Down Expand Up @@ -354,6 +357,86 @@ public function usingCandidateCount(int $candidateCount): self
return $this;
}

/**
* Sets the function declarations available to the model.
*
* @since n.e.x.t
*
* @param FunctionDeclaration ...$functionDeclarations The function declarations.
* @return self
*/
public function usingFunctionDeclarations(FunctionDeclaration ...$functionDeclarations): self
{
$this->modelConfig->setFunctionDeclarations($functionDeclarations);
return $this;
}

/**
* Sets the presence penalty for generation.
*
* @since n.e.x.t
*
* @param float $presencePenalty The presence penalty value.
* @return self
*/
public function usingPresencePenalty(float $presencePenalty): self
{
$this->modelConfig->setPresencePenalty($presencePenalty);
return $this;
}

/**
* Sets the frequency penalty for generation.
*
* @since n.e.x.t
*
* @param float $frequencyPenalty The frequency penalty value.
* @return self
*/
public function usingFrequencyPenalty(float $frequencyPenalty): self
{
$this->modelConfig->setFrequencyPenalty($frequencyPenalty);
return $this;
}

/**
* Sets the web search configuration.
*
* @since n.e.x.t
*
* @param WebSearch $webSearch The web search configuration.
* @return self
*/
public function usingWebSearch(WebSearch $webSearch): self
{
$this->modelConfig->setWebSearch($webSearch);
return $this;
}

/**
* Sets the top log probabilities configuration.
*
* If $topLogprobs is null, enables log probabilities.
* If $topLogprobs has a value, enables log probabilities and sets the number of top log probabilities to return.
*
* @since n.e.x.t
*
* @param int|null $topLogprobs The number of top log probabilities to return, or null to enable log probabilities.
* @return self
*/
public function usingTopLogprobs(?int $topLogprobs = null): self
{
// Always enable log probabilities
$this->modelConfig->setLogprobs(true);

// If a specific number is provided, set it
if ($topLogprobs !== null) {
$this->modelConfig->setTopLogprobs($topLogprobs);
}

return $this;
}

/**
* Sets the output MIME type.
*
Expand Down Expand Up @@ -396,6 +479,20 @@ public function asOutputModalities(ModalityEnum ...$modalities): self
return $this;
}

/**
* Sets the output file type.
*
* @since n.e.x.t
*
* @param FileTypeEnum $fileType The output file type.
* @return self
*/
public function asOutputFileType(FileTypeEnum $fileType): self
{
$this->modelConfig->setOutputFileType($fileType);
return $this;
}

/**
* Configures the prompt for JSON response output.
*
Expand Down
246 changes: 246 additions & 0 deletions tests/unit/Builders/PromptBuilderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use RuntimeException;
use WordPress\AiClient\Builders\PromptBuilder;
use WordPress\AiClient\Files\DTO\File;
use WordPress\AiClient\Files\Enums\FileTypeEnum;
use WordPress\AiClient\Messages\DTO\Message;
use WordPress\AiClient\Messages\DTO\MessagePart;
use WordPress\AiClient\Messages\DTO\ModelMessage;
Expand All @@ -31,7 +32,9 @@
use WordPress\AiClient\Results\DTO\TokenUsage;
use WordPress\AiClient\Results\Enums\FinishReasonEnum;
use WordPress\AiClient\Tests\traits\MockModelCreationTrait;
use WordPress\AiClient\Tools\DTO\FunctionDeclaration;
use WordPress\AiClient\Tools\DTO\FunctionResponse;
use WordPress\AiClient\Tools\DTO\WebSearch;

/**
* @covers \WordPress\AiClient\Builders\PromptBuilder
Expand Down Expand Up @@ -2787,4 +2790,247 @@ public function testFluentInterfaceWithProvider(): void
$this->assertEquals(500, $config->getMaxTokens());
$this->assertEquals(0.7, $config->getTemperature());
}

/**
* Tests usingFunctionDeclarations method.
*
* @return void
*/
public function testUsingFunctionDeclarations(): void
{
$builder = new PromptBuilder($this->registry);

$functionDeclaration1 = new FunctionDeclaration(
'test_function',
'A test function',
['param1' => ['type' => 'string']]
);
$functionDeclaration2 = new FunctionDeclaration(
'another_function',
'Another test function',
['param2' => ['type' => 'integer']]
);

$result = $builder->usingFunctionDeclarations($functionDeclaration1, $functionDeclaration2);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$functionDeclarations = $config->getFunctionDeclarations();
$this->assertIsArray($functionDeclarations);
$this->assertCount(2, $functionDeclarations);
$this->assertSame($functionDeclaration1, $functionDeclarations[0]);
$this->assertSame($functionDeclaration2, $functionDeclarations[1]);
}

/**
* Tests usingPresencePenalty method.
*
* @return void
*/
public function testUsingPresencePenalty(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->usingPresencePenalty(0.5);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$this->assertEquals(0.5, $config->getPresencePenalty());
}

/**
* Tests usingFrequencyPenalty method.
*
* @return void
*/
public function testUsingFrequencyPenalty(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->usingFrequencyPenalty(0.8);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$this->assertEquals(0.8, $config->getFrequencyPenalty());
}

/**
* Tests usingWebSearch method.
*
* @return void
*/
public function testUsingWebSearch(): void
{
$builder = new PromptBuilder($this->registry);

$webSearch = new WebSearch(
['allowed.com', 'trusted.org'],
['blocked.com', 'spam.net']
);

$result = $builder->usingWebSearch($webSearch);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$configWebSearch = $config->getWebSearch();
$this->assertNotNull($configWebSearch);
$this->assertSame($webSearch, $configWebSearch);
$this->assertEquals(['allowed.com', 'trusted.org'], $configWebSearch->getAllowedDomains());
$this->assertEquals(['blocked.com', 'spam.net'], $configWebSearch->getDisallowedDomains());
}

/**
* Tests asOutputFileType method.
*
* @return void
*/
public function testAsOutputFileType(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->asOutputFileType(FileTypeEnum::inline());

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$outputFileType = $config->getOutputFileType();
$this->assertNotNull($outputFileType);
$this->assertTrue($outputFileType->isInline());
}

/**
* Tests asOutputFileType method with remote file type.
*
* @return void
*/
public function testAsOutputFileTypeWithRemote(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->asOutputFileType(FileTypeEnum::remote());

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$outputFileType = $config->getOutputFileType();
$this->assertNotNull($outputFileType);
$this->assertTrue($outputFileType->isRemote());
}

/**
* Tests usingTopLogprobs method with null value (only enables logprobs).
*
* @return void
*/
public function testUsingTopLogprobsWithNull(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->usingTopLogprobs();

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$this->assertTrue($config->getLogprobs());
$this->assertNull($config->getTopLogprobs());
}

/**
* Tests usingTopLogprobs method with specific value.
*
* @return void
*/
public function testUsingTopLogprobsWithValue(): void
{
$builder = new PromptBuilder($this->registry);
$result = $builder->usingTopLogprobs(5);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$this->assertTrue($config->getLogprobs());
$this->assertEquals(5, $config->getTopLogprobs());
}

/**
* Tests method chaining with multiple new methods.
*
* @return void
*/
public function testMethodChainingWithNewMethods(): void
{
$builder = new PromptBuilder($this->registry);

$functionDeclaration = new FunctionDeclaration(
'test_function',
'A test function',
['param1' => ['type' => 'string']]
);

$webSearch = new WebSearch(['allowed.com'], ['blocked.com']);

$result = $builder
->withText('Test prompt')
->usingPresencePenalty(0.5)
->usingFrequencyPenalty(0.7)
->usingFunctionDeclarations($functionDeclaration)
->usingWebSearch($webSearch)
->asOutputFileType(FileTypeEnum::inline())
->usingTopLogprobs(3);

$this->assertSame($builder, $result);

$reflection = new \ReflectionClass($builder);
$configProperty = $reflection->getProperty('modelConfig');
$configProperty->setAccessible(true);
/** @var ModelConfig $config */
$config = $configProperty->getValue($builder);

$this->assertEquals(0.5, $config->getPresencePenalty());
$this->assertEquals(0.7, $config->getFrequencyPenalty());
$this->assertCount(1, $config->getFunctionDeclarations());
$this->assertNotNull($config->getWebSearch());
$this->assertTrue($config->getOutputFileType()->isInline());
$this->assertTrue($config->getLogprobs());
$this->assertEquals(3, $config->getTopLogprobs());
}
}