Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 75 additions & 27 deletions src/Providers/Models/DTO/ModelConfig.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
use WordPress\AiClient\Files\Enums\FileTypeEnum;
use WordPress\AiClient\Files\Enums\MediaOrientationEnum;
use WordPress\AiClient\Messages\Enums\ModalityEnum;
use WordPress\AiClient\Tools\DTO\Tool;
use WordPress\AiClient\Tools\DTO\FunctionDeclaration;
use WordPress\AiClient\Tools\DTO\WebSearch;

/**
* Represents configuration for an AI model.
Expand All @@ -20,7 +21,8 @@
*
* @since n.e.x.t
*
* @phpstan-import-type ToolArrayShape from Tool
* @phpstan-import-type FunctionDeclarationArrayShape from FunctionDeclaration
* @phpstan-import-type WebSearchArrayShape from WebSearch
*
* @phpstan-type ModelConfigArrayShape array{
* outputModalities?: list<string>,
Expand All @@ -35,7 +37,8 @@
* frequencyPenalty?: float,
* logprobs?: bool,
* topLogprobs?: int,
* tools?: list<ToolArrayShape>,
* functionDeclarations?: list<FunctionDeclarationArrayShape>,
* webSearch?: WebSearchArrayShape,
* outputFileType?: string,
* outputMimeType?: string,
* outputSchema?: array<string, mixed>,
Expand All @@ -60,7 +63,8 @@ class ModelConfig extends AbstractDataTransferObject
public const KEY_FREQUENCY_PENALTY = 'frequencyPenalty';
public const KEY_LOGPROBS = 'logprobs';
public const KEY_TOP_LOGPROBS = 'topLogprobs';
public const KEY_TOOLS = 'tools';
public const KEY_FUNCTION_DECLARATIONS = 'functionDeclarations';
public const KEY_WEB_SEARCH = 'webSearch';
public const KEY_OUTPUT_FILE_TYPE = 'outputFileType';
public const KEY_OUTPUT_MIME_TYPE = 'outputMimeType';
public const KEY_OUTPUT_SCHEMA = 'outputSchema';
Expand Down Expand Up @@ -129,9 +133,14 @@ class ModelConfig extends AbstractDataTransferObject
protected ?int $topLogprobs = null;

/**
* @var list<Tool>|null Tools available to the model.
* @var list<FunctionDeclaration>|null Function declarations available to the model.
*/
protected ?array $tools = null;
protected ?array $functionDeclarations = null;

/**
* @var WebSearch|null Web search configuration for the model.
*/
protected ?WebSearch $webSearch = null;

/**
* @var FileTypeEnum|null Output file type.
Expand Down Expand Up @@ -464,33 +473,57 @@ public function getTopLogprobs(): ?int
}

/**
* Sets the tools.
* Sets the function declarations.
*
* @since n.e.x.t
*
* @param list<Tool> $tools The tools.
* @param list<FunctionDeclaration> $function_declarations The function declarations.
*
* @throws InvalidArgumentException If the array is not a list.
*/
public function setTools(array $tools): void
public function setFunctionDeclarations(array $function_declarations): void
{
if (!array_is_list($tools)) {
throw new InvalidArgumentException('Tools must be a list array.');
if (!array_is_list($function_declarations)) {
throw new InvalidArgumentException('Function declarations must be a list array.');
}

$this->tools = $tools;
$this->functionDeclarations = $function_declarations;
}

/**
* Gets the function declarations.
*
* @since n.e.x.t
*
* @return list<FunctionDeclaration>|null The function declarations.
*/
public function getFunctionDeclarations(): ?array
{
return $this->functionDeclarations;
}

/**
* Gets the tools.
* Sets the web search configuration.
*
* @since n.e.x.t
*
* @return list<Tool>|null The tools.
* @param WebSearch $web_search The web search configuration.
*/
public function getTools(): ?array
public function setWebSearch(WebSearch $web_search): void
{
return $this->tools;
$this->webSearch = $web_search;
}

/**
* Gets the web search configuration.
*
* @since n.e.x.t
*
* @return WebSearch|null The web search configuration.
*/
public function getWebSearch(): ?WebSearch
{
return $this->webSearch;
}

/**
Expand Down Expand Up @@ -738,11 +771,12 @@ public static function getJsonSchema(): array
'minimum' => 1,
'description' => 'Number of top log probabilities to return.',
],
self::KEY_TOOLS => [
self::KEY_FUNCTION_DECLARATIONS => [
'type' => 'array',
'items' => Tool::getJsonSchema(),
'description' => 'Tools available to the model.',
'items' => FunctionDeclaration::getJsonSchema(),
'description' => 'Function declarations available to the model.',
],
self::KEY_WEB_SEARCH => WebSearch::getJsonSchema(),
self::KEY_OUTPUT_FILE_TYPE => [
'type' => 'string',
'enum' => FileTypeEnum::getValues(),
Expand Down Expand Up @@ -841,10 +875,17 @@ static function (ModalityEnum $modality): string {
$data[self::KEY_TOP_LOGPROBS] = $this->topLogprobs;
}

if ($this->tools !== null) {
$data[self::KEY_TOOLS] = array_map(static function (Tool $tool): array {
return $tool->toArray();
}, $this->tools);
if ($this->functionDeclarations !== null) {
$data[self::KEY_FUNCTION_DECLARATIONS] = array_map(
static function (FunctionDeclaration $function_declaration): array {
return $function_declaration->toArray();
},
$this->functionDeclarations
);
}

if ($this->webSearch !== null) {
$data[self::KEY_WEB_SEARCH] = $this->webSearch->toArray();
}

if ($this->outputFileType !== null) {
Expand Down Expand Up @@ -932,10 +973,17 @@ public static function fromArray(array $array): self
$config->setTopLogprobs($array[self::KEY_TOP_LOGPROBS]);
}

if (isset($array[self::KEY_TOOLS])) {
$config->setTools(array_map(static function (array $toolData): Tool {
return Tool::fromArray($toolData);
}, $array[self::KEY_TOOLS]));
if (isset($array[self::KEY_FUNCTION_DECLARATIONS])) {
$config->setFunctionDeclarations(array_map(
static function (array $function_declaration_data): FunctionDeclaration {
return FunctionDeclaration::fromArray($function_declaration_data);
},
$array[self::KEY_FUNCTION_DECLARATIONS]
));
}

if (isset($array[self::KEY_WEB_SEARCH])) {
$config->setWebSearch(WebSearch::fromArray($array[self::KEY_WEB_SEARCH]));
}

if (isset($array[self::KEY_OUTPUT_FILE_TYPE])) {
Expand Down
193 changes: 0 additions & 193 deletions src/Tools/DTO/Tool.php

This file was deleted.

Loading