Skip to content
This repository was archived by the owner on Jul 16, 2025. It is now read-only.

Commit ef66553

Browse files
committed
feat: add support for hugging face inference api
1 parent c4deb6d commit ef66553

File tree

10 files changed

+270
-0
lines changed

10 files changed

+270
-0
lines changed

.env

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ AZURE_OPENAI_KEY=
2525
AZURE_LLAMA_BASEURL=
2626
AZURE_LLAMA_KEY=
2727

28+
# Hugging Face Access Token
29+
HUGGINGFACE_KEY=
30+
2831
# For using OpenRouter
2932
OPENROUTER_KEY=
3033

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?php
2+
3+
use PhpLlm\LlmChain\Bridge\HuggingFace\Model;
4+
use PhpLlm\LlmChain\Bridge\HuggingFace\PlatformFactory;
5+
use PhpLlm\LlmChain\Bridge\HuggingFace\Task;
6+
use Symfony\Component\Dotenv\Dotenv;
7+
8+
require_once dirname(__DIR__).'/vendor/autoload.php';
9+
(new Dotenv())->loadEnv(dirname(__DIR__).'/.env');
10+
11+
if (empty($_ENV['HUGGINGFACE_KEY'])) {
12+
echo 'Please set the HUGGINGFACE_KEY environment variable.'.PHP_EOL;
13+
exit(1);
14+
}
15+
16+
$platform = PlatformFactory::create($_ENV['HUGGINGFACE_KEY']);
17+
$model = new Model('ProsusAI/finbert');
18+
19+
$response = $platform->request($model, 'I like you. I love you.', [
20+
'task' => Task::TEXT_CLASSIFICATION,
21+
]);
22+
23+
dump($response->getContent());

src/Bridge/HuggingFace/Model.php

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
use PhpLlm\LlmChain\Model\Model as BaseModel;
8+
9+
final readonly class Model implements BaseModel
10+
{
11+
/**
12+
* @param string $name the name of the model is optional with HuggingFace
13+
* @param array<string, mixed> $options
14+
*/
15+
public function __construct(
16+
private ?string $name = null,
17+
private array $options = [],
18+
) {
19+
}
20+
21+
public function getName(): string
22+
{
23+
return $this->name ?? '';
24+
}
25+
26+
public function getOptions(): array
27+
{
28+
return $this->options;
29+
}
30+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
use PhpLlm\LlmChain\Model\Model as BaseModel;
8+
use PhpLlm\LlmChain\Platform\ModelClient as PlatformModelClient;
9+
use Symfony\Component\HttpClient\EventSourceHttpClient;
10+
use Symfony\Contracts\HttpClient\HttpClientInterface;
11+
use Symfony\Contracts\HttpClient\ResponseInterface;
12+
use Webmozart\Assert\Assert;
13+
14+
final readonly class ModelClient implements PlatformModelClient
15+
{
16+
private EventSourceHttpClient $httpClient;
17+
18+
public function __construct(
19+
HttpClientInterface $httpClient,
20+
private string $provider,
21+
#[\SensitiveParameter]
22+
private string $apiKey,
23+
) {
24+
$this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
25+
}
26+
27+
public function supports(BaseModel $model, object|array|string $input): bool
28+
{
29+
return $model instanceof Model;
30+
}
31+
32+
public function request(BaseModel $model, object|array|string $input, array $options = []): ResponseInterface
33+
{
34+
Assert::isInstanceOf($model, Model::class);
35+
$url = sprintf('https://router.huggingface.co/%s/models/%s', $this->provider, $model->getName());
36+
37+
return $this->httpClient->request('POST', $url, [
38+
'auth_bearer' => $this->apiKey,
39+
'headers' => ['Content-Type' => 'application/json'],
40+
'json' => $this->getPayload($input, $options),
41+
]);
42+
}
43+
44+
/**
45+
* @param array<mixed>|string|object $input
46+
* @param array<string, mixed> $options
47+
*
48+
* @return array<string, mixed>
49+
*/
50+
public function getPayload(object|array|string $input, array $options = []): array
51+
{
52+
$task = $options['task'] ?? null;
53+
54+
if (Task::TEXT_CLASSIFICATION === $task) {
55+
return [
56+
'inputs' => $input,
57+
];
58+
}
59+
60+
throw new \InvalidArgumentException('Unsupported task: '.$task);
61+
}
62+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace\Output;
6+
7+
final readonly class Classification
8+
{
9+
public function __construct(
10+
public string $label,
11+
public float $score,
12+
) {
13+
}
14+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace\Output;
6+
7+
final class ClassificationResult
8+
{
9+
/**
10+
* @param Classification[] $classifications
11+
*/
12+
public function __construct(
13+
public array $classifications,
14+
) {
15+
}
16+
17+
/**
18+
* @param array<array{label: string, score: float}> $data
19+
*/
20+
public static function collectionFromArray(array $data): self
21+
{
22+
return new self(array_map(
23+
fn (array $item) => new Classification($item['label'], $item['score']),
24+
$data,
25+
));
26+
}
27+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
use PhpLlm\LlmChain\Platform;
8+
use Symfony\Component\HttpClient\EventSourceHttpClient;
9+
use Symfony\Contracts\HttpClient\HttpClientInterface;
10+
11+
final readonly class PlatformFactory
12+
{
13+
public static function create(
14+
#[\SensitiveParameter]
15+
string $apiKey,
16+
string $provider = Provider::HF_INFERENCE,
17+
?HttpClientInterface $httpClient = null,
18+
): Platform {
19+
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
20+
21+
return new Platform([new ModelClient($httpClient, $provider, $apiKey)], [new ResponseConverter()]);
22+
}
23+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
interface Provider
8+
{
9+
public const CEREBRAS = 'cerebras';
10+
public const FAL_AI = 'fal-ai';
11+
public const FIREWORKS = 'fireworks-ai';
12+
public const HYPERBOLIC = 'hyperbolic';
13+
public const HF_INFERENCE = 'hf-inference';
14+
public const NEBIUS = 'nebius';
15+
public const NOVITA = 'novita';
16+
public const REPLICATE = 'replicate';
17+
public const SAMBA_NOVA = 'sambanova';
18+
public const TOGETHER = 'together';
19+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
use PhpLlm\LlmChain\Bridge\HuggingFace\Output\ClassificationResult;
8+
use PhpLlm\LlmChain\Model\Model as BaseModel;
9+
use PhpLlm\LlmChain\Model\Response\ResponseInterface as LlmResponse;
10+
use PhpLlm\LlmChain\Model\Response\StructuredResponse;
11+
use PhpLlm\LlmChain\Platform\ResponseConverter as PlatformResponseConverter;
12+
use Symfony\Contracts\HttpClient\ResponseInterface;
13+
14+
final readonly class ResponseConverter implements PlatformResponseConverter
15+
{
16+
public function supports(BaseModel $model, array|string|object $input): bool
17+
{
18+
return $model instanceof Model;
19+
}
20+
21+
public function convert(ResponseInterface $response, array $options = []): LlmResponse
22+
{
23+
$data = $response->toArray();
24+
$task = $options['task'] ?? null;
25+
26+
if (Task::TEXT_CLASSIFICATION === $task) {
27+
return new StructuredResponse(ClassificationResult::collectionFromArray(reset($data) ?? []));
28+
}
29+
30+
throw new \RuntimeException(sprintf('Unsupported task: %s', $task));
31+
}
32+
}

src/Bridge/HuggingFace/Task.php

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Bridge\HuggingFace;
6+
7+
interface Task
8+
{
9+
public const AUDIO_CLASSIFICATION = 'audio-classification';
10+
public const AUDIO_TO_AUDIO = 'audio-to-audio';
11+
public const AUTOMATIC_SPEECH_RECOGNITION = 'automatic-speech-recognition';
12+
public const CHAT_COMPLETION = 'chat-completion';
13+
public const DOCUMENT_QUESTION_ANSWERING = 'document-question-answering';
14+
public const FEATURE_EXTRACTION = 'feature-extraction';
15+
public const FILL_MASK = 'fill-mask';
16+
public const IMAGE_CLASSIFICATION = 'image-classification';
17+
public const IMAGE_SEGMENTATION = 'image-segmentation';
18+
public const IMAGE_TO_IMAGE = 'image-to-image';
19+
public const IMAGE_TO_TEXT = 'image-to-text';
20+
public const OBJECT_DETECTION = 'object-detection';
21+
public const QUESTION_ANSWERING = 'question-answering';
22+
public const SENTENCE_SIMILARITY = 'sentence-similarity';
23+
public const SUMMARIZATION = 'summarization';
24+
public const TABLE_QUESTION_ANSWERING = 'table-question-answering';
25+
public const TABULAR_CLASSIFICATION = 'tabular-classification';
26+
public const TABULAR_REGRESSION = 'tabular-regression';
27+
public const TEXT_CLASSIFICATION = 'text-classification';
28+
public const TEXT_GENERATION = 'text-generation';
29+
public const TEXT_TO_IMAGE = 'text-to-image';
30+
public const TEXT_TO_VIDEO = 'text-to-video';
31+
public const TEXT_TO_SPEECH = 'text-to-speech';
32+
public const TOKEN_CLASSIFICATION = 'token-classification';
33+
public const TRANSLATION = 'translation';
34+
public const VISUAL_QUESTION_ANSWERING = 'visual-question-answering';
35+
public const ZERO_SHOT_CLASSIFICATION = 'zero-shot-classification';
36+
public const ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification';
37+
}

0 commit comments

Comments
 (0)