Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 64 additions & 16 deletions src/Builders/PromptBuilder.php
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class PromptBuilder
*/
protected ?ModelInterface $model = null;

/**
* @var string|null The provider ID or class name.
*/
protected ?string $providerIdOrClassName = null;

/**
* @var ModelConfig The model configuration.
*/
Expand Down Expand Up @@ -198,6 +203,20 @@ public function usingModel(ModelInterface $model): self
return $this;
}

/**
* Sets the provider to use for generation.
*
* @since n.e.x.t
*
* @param string $providerIdOrClassName The provider ID or class name.
* @return self
*/
public function usingProvider(string $providerIdOrClassName): self
{
$this->providerIdOrClassName = $providerIdOrClassName;
return $this;
}

/**
* Sets the system instruction.
*
Expand Down Expand Up @@ -930,28 +949,57 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
}

// Find a suitable model based on requirements
$modelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
if ($this->providerIdOrClassName === null) {
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);

if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
'No models found that support the required capabilities and options for this prompt. ' .
'Required capabilities: ' . implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())) .
'. Required options: ' . implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
if (empty($providerModelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
}

$firstProviderModels = $providerModelsMetadata[0];
$provider = $firstProviderModels->getProvider()->getId();
$modelMetadata = $firstProviderModels->getModels()[0];
} else {
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
$this->providerIdOrClassName,
$requirements
);
}

// Get the first available model from the first provider
$firstProviderModels = $modelsMetadata[0];
$firstModelMetadata = $firstProviderModels->getModels()[0];
if (empty($modelsMetadata)) {
throw new InvalidArgumentException(
sprintf(
'No models found for %s that support the required capabilities and options for this prompt. ' .
'Required capabilities: %s. Required options: %s',
$this->providerIdOrClassName,
implode(', ', array_map(function ($cap) {
return $cap->value;
}, $requirements->getRequiredCapabilities())),
implode(', ', array_map(function ($opt) {
return $opt->getName()->value . '=' . json_encode($opt->getValue());
}, $requirements->getRequiredOptions()))
)
);
}

$provider = $this->providerIdOrClassName;
$modelMetadata = $modelsMetadata[0];
}

// Get the model instance from the provider
return $this->registry->getProviderModel(
$firstProviderModels->getProvider()->getId(),
$firstModelMetadata->getId(),
$provider,
$modelMetadata->getId(),
$this->modelConfig
);
}
Expand Down
22 changes: 21 additions & 1 deletion src/Results/Contracts/ResultInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

namespace WordPress\AiClient\Results\Contracts;

use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Results\DTO\TokenUsage;

/**
Expand Down Expand Up @@ -34,12 +36,30 @@ public function getId(): string;
*/
public function getTokenUsage(): TokenUsage;

/**
* Gets the provider metadata.
*
* @since n.e.x.t
*
* @return ProviderMetadata The provider metadata.
*/
public function getProviderMetadata(): ProviderMetadata;

/**
* Gets the model metadata.
*
* @since n.e.x.t
*
* @return ModelMetadata The model metadata.
*/
public function getModelMetadata(): ModelMetadata;

/**
* Gets provider-specific metadata.
*
* @since n.e.x.t
*
* @return array<string, mixed> Provider metadata.
*/
public function getProviderMetadata(): array;
public function getAdditionalData(): array;
}
98 changes: 84 additions & 14 deletions src/Results/DTO/GenerativeAiResult.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
use WordPress\AiClient\Common\AbstractDataTransferObject;
use WordPress\AiClient\Files\DTO\File;
use WordPress\AiClient\Messages\DTO\Message;
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Results\Contracts\ResultInterface;

/**
Expand All @@ -21,12 +23,16 @@
*
* @phpstan-import-type CandidateArrayShape from Candidate
* @phpstan-import-type TokenUsageArrayShape from TokenUsage
* @phpstan-import-type ProviderMetadataArrayShape from ProviderMetadata
* @phpstan-import-type ModelMetadataArrayShape from ModelMetadata
*
* @phpstan-type GenerativeAiResultArrayShape array{
* id: string,
* candidates: array<CandidateArrayShape>,
* tokenUsage: TokenUsageArrayShape,
* providerMetadata?: array<string, mixed>
* providerMetadata: ProviderMetadataArrayShape,
* modelMetadata: ModelMetadataArrayShape,
* additionalData?: array<string, mixed>
* }
*
* @extends AbstractDataTransferObject<GenerativeAiResultArrayShape>
Expand All @@ -37,6 +43,8 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
public const KEY_CANDIDATES = 'candidates';
public const KEY_TOKEN_USAGE = 'tokenUsage';
public const KEY_PROVIDER_METADATA = 'providerMetadata';
public const KEY_MODEL_METADATA = 'modelMetadata';
public const KEY_ADDITIONAL_DATA = 'additionalData';
/**
* @var string Unique identifier for this result.
*/
Expand All @@ -53,9 +61,19 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
private TokenUsage $tokenUsage;

/**
* @var array<string, mixed> Provider-specific metadata.
* @var ProviderMetadata Provider metadata.
*/
private array $providerMetadata;
private ProviderMetadata $providerMetadata;

/**
* @var ModelMetadata Model metadata.
*/
private ModelMetadata $modelMetadata;

/**
* @var array<string, mixed> Additional data.
*/
private array $additionalData;

/**
* Constructor.
Expand All @@ -65,11 +83,19 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
* @param string $id Unique identifier for this result.
* @param Candidate[] $candidates The generated candidates.
* @param TokenUsage $tokenUsage Token usage statistics.
* @param array<string, mixed> $providerMetadata Provider-specific metadata.
* @param ProviderMetadata $providerMetadata Provider metadata.
* @param ModelMetadata $modelMetadata Model metadata.
* @param array<string, mixed> $additionalData Additional data.
* @throws InvalidArgumentException If no candidates provided.
*/
public function __construct(string $id, array $candidates, TokenUsage $tokenUsage, array $providerMetadata = [])
{
public function __construct(
string $id,
array $candidates,
TokenUsage $tokenUsage,
ProviderMetadata $providerMetadata,
ModelMetadata $modelMetadata,
array $additionalData = []
) {
if (empty($candidates)) {
throw new InvalidArgumentException('At least one candidate must be provided');
}
Expand All @@ -78,6 +104,8 @@ public function __construct(string $id, array $candidates, TokenUsage $tokenUsag
$this->candidates = $candidates;
$this->tokenUsage = $tokenUsage;
$this->providerMetadata = $providerMetadata;
$this->modelMetadata = $modelMetadata;
$this->additionalData = $additionalData;
}

/**
Expand Down Expand Up @@ -113,15 +141,39 @@ public function getTokenUsage(): TokenUsage
}

/**
* {@inheritDoc}
* Gets the provider metadata.
*
* @since n.e.x.t
*
* @return ProviderMetadata The provider metadata.
*/
public function getProviderMetadata(): array
public function getProviderMetadata(): ProviderMetadata
{
return $this->providerMetadata;
}

/**
* Gets the model metadata.
*
* @since n.e.x.t
*
* @return ModelMetadata The model metadata.
*/
public function getModelMetadata(): ModelMetadata
{
return $this->modelMetadata;
}

/**
* {@inheritDoc}
*
* @since n.e.x.t
*/
public function getAdditionalData(): array
{
return $this->additionalData;
}

/**
* Gets the total number of candidates.
*
Expand Down Expand Up @@ -387,13 +439,21 @@ public static function getJsonSchema(): array
'description' => 'The generated candidates.',
],
self::KEY_TOKEN_USAGE => TokenUsage::getJsonSchema(),
self::KEY_PROVIDER_METADATA => [
self::KEY_PROVIDER_METADATA => ProviderMetadata::getJsonSchema(),
self::KEY_MODEL_METADATA => ModelMetadata::getJsonSchema(),
self::KEY_ADDITIONAL_DATA => [
'type' => 'object',
'additionalProperties' => true,
'description' => 'Provider-specific metadata.',
'description' => 'Additional data included in the API response.',
],
],
'required' => [self::KEY_ID, self::KEY_CANDIDATES, self::KEY_TOKEN_USAGE],
'required' => [
self::KEY_ID,
self::KEY_CANDIDATES,
self::KEY_TOKEN_USAGE,
self::KEY_PROVIDER_METADATA,
self::KEY_MODEL_METADATA
],
];
}

Expand All @@ -410,7 +470,9 @@ public function toArray(): array
self::KEY_ID => $this->id,
self::KEY_CANDIDATES => array_map(fn(Candidate $candidate) => $candidate->toArray(), $this->candidates),
self::KEY_TOKEN_USAGE => $this->tokenUsage->toArray(),
self::KEY_PROVIDER_METADATA => $this->providerMetadata,
self::KEY_PROVIDER_METADATA => $this->providerMetadata->toArray(),
self::KEY_MODEL_METADATA => $this->modelMetadata->toArray(),
self::KEY_ADDITIONAL_DATA => $this->additionalData,
];
}

Expand All @@ -421,7 +483,13 @@ public function toArray(): array
*/
public static function fromArray(array $array): self
{
static::validateFromArrayData($array, [self::KEY_ID, self::KEY_CANDIDATES, self::KEY_TOKEN_USAGE]);
static::validateFromArrayData($array, [
self::KEY_ID,
self::KEY_CANDIDATES,
self::KEY_TOKEN_USAGE,
self::KEY_PROVIDER_METADATA,
self::KEY_MODEL_METADATA
]);

$candidates = array_map(
fn(array $candidateData) => Candidate::fromArray($candidateData),
Expand All @@ -432,7 +500,9 @@ public static function fromArray(array $array): self
$array[self::KEY_ID],
$candidates,
TokenUsage::fromArray($array[self::KEY_TOKEN_USAGE]),
$array[self::KEY_PROVIDER_METADATA] ?? []
ProviderMetadata::fromArray($array[self::KEY_PROVIDER_METADATA]),
ModelMetadata::fromArray($array[self::KEY_MODEL_METADATA]),
$array[self::KEY_ADDITIONAL_DATA] ?? []
);
}
}
Loading