Skip to content

Commit 16ca404

Browse files
Merge pull request #62 from WordPress/prompt-builder-provider-support
2 parents e5e1f18 + cce0fb3 commit 16ca404

File tree

6 files changed

+636
-93
lines changed

6 files changed

+636
-93
lines changed

src/Builders/PromptBuilder.php

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class PromptBuilder
5757
*/
5858
protected ?ModelInterface $model = null;
5959

60+
/**
61+
* @var string|null The provider ID or class name.
62+
*/
63+
protected ?string $providerIdOrClassName = null;
64+
6065
/**
6166
* @var ModelConfig The model configuration.
6267
*/
@@ -198,6 +203,20 @@ public function usingModel(ModelInterface $model): self
198203
return $this;
199204
}
200205

206+
/**
207+
* Sets the provider to use for generation.
208+
*
209+
* @since n.e.x.t
210+
*
211+
* @param string $providerIdOrClassName The provider ID or class name.
212+
* @return self
213+
*/
214+
public function usingProvider(string $providerIdOrClassName): self
215+
{
216+
$this->providerIdOrClassName = $providerIdOrClassName;
217+
return $this;
218+
}
219+
201220
/**
202221
* Sets the system instruction.
203222
*
@@ -930,28 +949,57 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface
930949
}
931950

932951
// Find a suitable model based on requirements
933-
$modelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
952+
if ($this->providerIdOrClassName === null) {
953+
$providerModelsMetadata = $this->registry->findModelsMetadataForSupport($requirements);
934954

935-
if (empty($modelsMetadata)) {
936-
throw new InvalidArgumentException(
937-
'No models found that support the required capabilities and options for this prompt. ' .
938-
'Required capabilities: ' . implode(', ', array_map(function ($cap) {
939-
return $cap->value;
940-
}, $requirements->getRequiredCapabilities())) .
941-
'. Required options: ' . implode(', ', array_map(function ($opt) {
942-
return $opt->getName()->value . '=' . json_encode($opt->getValue());
943-
}, $requirements->getRequiredOptions()))
955+
if (empty($providerModelsMetadata)) {
956+
throw new InvalidArgumentException(
957+
sprintf(
958+
'No models found that support the required capabilities and options for this prompt. ' .
959+
'Required capabilities: %s. Required options: %s',
960+
implode(', ', array_map(function ($cap) {
961+
return $cap->value;
962+
}, $requirements->getRequiredCapabilities())),
963+
implode(', ', array_map(function ($opt) {
964+
return $opt->getName()->value . '=' . json_encode($opt->getValue());
965+
}, $requirements->getRequiredOptions()))
966+
)
967+
);
968+
}
969+
970+
$firstProviderModels = $providerModelsMetadata[0];
971+
$provider = $firstProviderModels->getProvider()->getId();
972+
$modelMetadata = $firstProviderModels->getModels()[0];
973+
} else {
974+
$modelsMetadata = $this->registry->findProviderModelsMetadataForSupport(
975+
$this->providerIdOrClassName,
976+
$requirements
944977
);
945-
}
946978

947-
// Get the first available model from the first provider
948-
$firstProviderModels = $modelsMetadata[0];
949-
$firstModelMetadata = $firstProviderModels->getModels()[0];
979+
if (empty($modelsMetadata)) {
980+
throw new InvalidArgumentException(
981+
sprintf(
982+
'No models found for %s that support the required capabilities and options for this prompt. ' .
983+
'Required capabilities: %s. Required options: %s',
984+
$this->providerIdOrClassName,
985+
implode(', ', array_map(function ($cap) {
986+
return $cap->value;
987+
}, $requirements->getRequiredCapabilities())),
988+
implode(', ', array_map(function ($opt) {
989+
return $opt->getName()->value . '=' . json_encode($opt->getValue());
990+
}, $requirements->getRequiredOptions()))
991+
)
992+
);
993+
}
994+
995+
$provider = $this->providerIdOrClassName;
996+
$modelMetadata = $modelsMetadata[0];
997+
}
950998

951999
// Get the model instance from the provider
9521000
return $this->registry->getProviderModel(
953-
$firstProviderModels->getProvider()->getId(),
954-
$firstModelMetadata->getId(),
1001+
$provider,
1002+
$modelMetadata->getId(),
9551003
$this->modelConfig
9561004
);
9571005
}

src/Results/Contracts/ResultInterface.php

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
namespace WordPress\AiClient\Results\Contracts;
66

7+
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
8+
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
79
use WordPress\AiClient\Results\DTO\TokenUsage;
810

911
/**
@@ -34,12 +36,30 @@ public function getId(): string;
3436
*/
3537
public function getTokenUsage(): TokenUsage;
3638

39+
/**
40+
* Gets the provider metadata.
41+
*
42+
* @since n.e.x.t
43+
*
44+
* @return ProviderMetadata The provider metadata.
45+
*/
46+
public function getProviderMetadata(): ProviderMetadata;
47+
48+
/**
49+
* Gets the model metadata.
50+
*
51+
* @since n.e.x.t
52+
*
53+
* @return ModelMetadata The model metadata.
54+
*/
55+
public function getModelMetadata(): ModelMetadata;
56+
3757
/**
3858
* Gets provider-specific metadata.
3959
*
4060
* @since n.e.x.t
4161
*
4262
* @return array<string, mixed> Provider metadata.
4363
*/
44-
public function getProviderMetadata(): array;
64+
public function getAdditionalData(): array;
4565
}

src/Results/DTO/GenerativeAiResult.php

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
use WordPress\AiClient\Common\AbstractDataTransferObject;
1010
use WordPress\AiClient\Files\DTO\File;
1111
use WordPress\AiClient\Messages\DTO\Message;
12+
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
13+
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
1214
use WordPress\AiClient\Results\Contracts\ResultInterface;
1315

1416
/**
@@ -21,12 +23,16 @@
2123
*
2224
* @phpstan-import-type CandidateArrayShape from Candidate
2325
* @phpstan-import-type TokenUsageArrayShape from TokenUsage
26+
* @phpstan-import-type ProviderMetadataArrayShape from ProviderMetadata
27+
* @phpstan-import-type ModelMetadataArrayShape from ModelMetadata
2428
*
2529
* @phpstan-type GenerativeAiResultArrayShape array{
2630
* id: string,
2731
* candidates: array<CandidateArrayShape>,
2832
* tokenUsage: TokenUsageArrayShape,
29-
* providerMetadata?: array<string, mixed>
33+
* providerMetadata: ProviderMetadataArrayShape,
34+
* modelMetadata: ModelMetadataArrayShape,
35+
* additionalData?: array<string, mixed>
3036
* }
3137
*
3238
* @extends AbstractDataTransferObject<GenerativeAiResultArrayShape>
@@ -37,6 +43,8 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
3743
public const KEY_CANDIDATES = 'candidates';
3844
public const KEY_TOKEN_USAGE = 'tokenUsage';
3945
public const KEY_PROVIDER_METADATA = 'providerMetadata';
46+
public const KEY_MODEL_METADATA = 'modelMetadata';
47+
public const KEY_ADDITIONAL_DATA = 'additionalData';
4048
/**
4149
* @var string Unique identifier for this result.
4250
*/
@@ -53,9 +61,19 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
5361
private TokenUsage $tokenUsage;
5462

5563
/**
56-
* @var array<string, mixed> Provider-specific metadata.
64+
* @var ProviderMetadata Provider metadata.
5765
*/
58-
private array $providerMetadata;
66+
private ProviderMetadata $providerMetadata;
67+
68+
/**
69+
* @var ModelMetadata Model metadata.
70+
*/
71+
private ModelMetadata $modelMetadata;
72+
73+
/**
74+
* @var array<string, mixed> Additional data.
75+
*/
76+
private array $additionalData;
5977

6078
/**
6179
* Constructor.
@@ -65,11 +83,19 @@ class GenerativeAiResult extends AbstractDataTransferObject implements ResultInt
6583
* @param string $id Unique identifier for this result.
6684
* @param Candidate[] $candidates The generated candidates.
6785
* @param TokenUsage $tokenUsage Token usage statistics.
68-
* @param array<string, mixed> $providerMetadata Provider-specific metadata.
86+
* @param ProviderMetadata $providerMetadata Provider metadata.
87+
* @param ModelMetadata $modelMetadata Model metadata.
88+
* @param array<string, mixed> $additionalData Additional data.
6989
* @throws InvalidArgumentException If no candidates provided.
7090
*/
71-
public function __construct(string $id, array $candidates, TokenUsage $tokenUsage, array $providerMetadata = [])
72-
{
91+
public function __construct(
92+
string $id,
93+
array $candidates,
94+
TokenUsage $tokenUsage,
95+
ProviderMetadata $providerMetadata,
96+
ModelMetadata $modelMetadata,
97+
array $additionalData = []
98+
) {
7399
if (empty($candidates)) {
74100
throw new InvalidArgumentException('At least one candidate must be provided');
75101
}
@@ -78,6 +104,8 @@ public function __construct(string $id, array $candidates, TokenUsage $tokenUsag
78104
$this->candidates = $candidates;
79105
$this->tokenUsage = $tokenUsage;
80106
$this->providerMetadata = $providerMetadata;
107+
$this->modelMetadata = $modelMetadata;
108+
$this->additionalData = $additionalData;
81109
}
82110

83111
/**
@@ -113,15 +141,39 @@ public function getTokenUsage(): TokenUsage
113141
}
114142

115143
/**
116-
* {@inheritDoc}
144+
* Gets the provider metadata.
117145
*
118146
* @since n.e.x.t
147+
*
148+
* @return ProviderMetadata The provider metadata.
119149
*/
120-
public function getProviderMetadata(): array
150+
public function getProviderMetadata(): ProviderMetadata
121151
{
122152
return $this->providerMetadata;
123153
}
124154

155+
/**
156+
* Gets the model metadata.
157+
*
158+
* @since n.e.x.t
159+
*
160+
* @return ModelMetadata The model metadata.
161+
*/
162+
public function getModelMetadata(): ModelMetadata
163+
{
164+
return $this->modelMetadata;
165+
}
166+
167+
/**
168+
* {@inheritDoc}
169+
*
170+
* @since n.e.x.t
171+
*/
172+
public function getAdditionalData(): array
173+
{
174+
return $this->additionalData;
175+
}
176+
125177
/**
126178
* Gets the total number of candidates.
127179
*
@@ -387,13 +439,21 @@ public static function getJsonSchema(): array
387439
'description' => 'The generated candidates.',
388440
],
389441
self::KEY_TOKEN_USAGE => TokenUsage::getJsonSchema(),
390-
self::KEY_PROVIDER_METADATA => [
442+
self::KEY_PROVIDER_METADATA => ProviderMetadata::getJsonSchema(),
443+
self::KEY_MODEL_METADATA => ModelMetadata::getJsonSchema(),
444+
self::KEY_ADDITIONAL_DATA => [
391445
'type' => 'object',
392446
'additionalProperties' => true,
393-
'description' => 'Provider-specific metadata.',
447+
'description' => 'Additional data included in the API response.',
394448
],
395449
],
396-
'required' => [self::KEY_ID, self::KEY_CANDIDATES, self::KEY_TOKEN_USAGE],
450+
'required' => [
451+
self::KEY_ID,
452+
self::KEY_CANDIDATES,
453+
self::KEY_TOKEN_USAGE,
454+
self::KEY_PROVIDER_METADATA,
455+
self::KEY_MODEL_METADATA
456+
],
397457
];
398458
}
399459

@@ -410,7 +470,9 @@ public function toArray(): array
410470
self::KEY_ID => $this->id,
411471
self::KEY_CANDIDATES => array_map(fn(Candidate $candidate) => $candidate->toArray(), $this->candidates),
412472
self::KEY_TOKEN_USAGE => $this->tokenUsage->toArray(),
413-
self::KEY_PROVIDER_METADATA => $this->providerMetadata,
473+
self::KEY_PROVIDER_METADATA => $this->providerMetadata->toArray(),
474+
self::KEY_MODEL_METADATA => $this->modelMetadata->toArray(),
475+
self::KEY_ADDITIONAL_DATA => $this->additionalData,
414476
];
415477
}
416478

@@ -421,7 +483,13 @@ public function toArray(): array
421483
*/
422484
public static function fromArray(array $array): self
423485
{
424-
static::validateFromArrayData($array, [self::KEY_ID, self::KEY_CANDIDATES, self::KEY_TOKEN_USAGE]);
486+
static::validateFromArrayData($array, [
487+
self::KEY_ID,
488+
self::KEY_CANDIDATES,
489+
self::KEY_TOKEN_USAGE,
490+
self::KEY_PROVIDER_METADATA,
491+
self::KEY_MODEL_METADATA
492+
]);
425493

426494
$candidates = array_map(
427495
fn(array $candidateData) => Candidate::fromArray($candidateData),
@@ -432,7 +500,9 @@ public static function fromArray(array $array): self
432500
$array[self::KEY_ID],
433501
$candidates,
434502
TokenUsage::fromArray($array[self::KEY_TOKEN_USAGE]),
435-
$array[self::KEY_PROVIDER_METADATA] ?? []
503+
ProviderMetadata::fromArray($array[self::KEY_PROVIDER_METADATA]),
504+
ModelMetadata::fromArray($array[self::KEY_MODEL_METADATA]),
505+
$array[self::KEY_ADDITIONAL_DATA] ?? []
436506
);
437507
}
438508
}

0 commit comments

Comments
 (0)