Skip to content

Commit 789357f

Browse files
committed
[Store][Postgres] allow store initialization with utilized distance
1 parent 11a4e56 commit 789357f

File tree

3 files changed

+158
-25
lines changed

3 files changed

+158
-25
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <fabien@symfony.com>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Store\Bridge\Postgres;
13+
14+
use OskarStark\Enum\Trait\Comparable;
15+
16+
/**
17+
* @author Denis Zunke <denis.zunke@gmail.com>
18+
*/
19+
enum Distance: string
20+
{
21+
use Comparable;
22+
23+
case Cosine = 'cosine';
24+
case InnerProduct = 'inner_product';
25+
case L1 = 'l1';
26+
case L2 = 'l2';
27+
28+
public function getComparisonSign(): string
29+
{
30+
return match ($this) {
31+
self::Cosine => '<=>',
32+
self::InnerProduct => '<#>',
33+
self::L1 => '<+>',
34+
self::L2 => '<->',
35+
};
36+
}
37+
}

src/store/src/Bridge/Postgres/Store.php

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,32 @@ public function __construct(
3434
private \PDO $connection,
3535
private string $tableName,
3636
private string $vectorFieldName = 'embedding',
37+
private Distance $distance = Distance::L2,
3738
) {
3839
}
3940

40-
public static function fromPdo(\PDO $connection, string $tableName, string $vectorFieldName = 'embedding'): self
41-
{
42-
return new self($connection, $tableName, $vectorFieldName);
41+
public static function fromPdo(
42+
\PDO $connection,
43+
string $tableName,
44+
string $vectorFieldName = 'embedding',
45+
Distance $distance = Distance::L2,
46+
): self {
47+
return new self($connection, $tableName, $vectorFieldName, $distance);
4348
}
4449

45-
public static function fromDbal(Connection $connection, string $tableName, string $vectorFieldName = 'embedding'): self
46-
{
50+
public static function fromDbal(
51+
Connection $connection,
52+
string $tableName,
53+
string $vectorFieldName = 'embedding',
54+
Distance $distance = Distance::L2,
55+
): self {
4756
$pdo = $connection->getNativeConnection();
4857

4958
if (!$pdo instanceof \PDO) {
5059
throw new InvalidArgumentException('Only DBAL connections using PDO driver are supported.');
5160
}
5261

53-
return self::fromPdo($pdo, $tableName, $vectorFieldName);
62+
return self::fromPdo($pdo, $tableName, $vectorFieldName, $distance);
5463
}
5564

5665
public function add(VectorDocument ...$documents): void
@@ -84,16 +93,18 @@ public function add(VectorDocument ...$documents): void
8493
*/
8594
public function query(Vector $vector, array $options = [], ?float $minScore = null): array
8695
{
87-
$sql = \sprintf(
88-
'SELECT id, %s AS embedding, metadata, (%s <-> :embedding) AS score
89-
FROM %s
90-
%s
91-
ORDER BY score ASC
92-
LIMIT %d',
96+
$sql = \sprintf(<<<SQL
97+
SELECT id, %s AS embedding, metadata, (%s %s :embedding) AS score
98+
FROM %s
99+
%s
100+
ORDER BY score ASC
101+
LIMIT %d
102+
SQL,
93103
$this->vectorFieldName,
94104
$this->vectorFieldName,
105+
$this->distance->getComparisonSign(),
95106
$this->tableName,
96-
null !== $minScore ? "WHERE ({$this->vectorFieldName} <-> :embedding) >= :minScore" : '',
107+
null !== $minScore ? "WHERE ({$this->vectorFieldName} {$this->distance->getComparisonSign()} :embedding) >= :minScore" : '',
97108
$options['limit'] ?? 5,
98109
);
99110
$statement = $this->connection->prepare($sql);

src/store/tests/Bridge/Postgres/StoreTest.php

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use PHPUnit\Framework\Attributes\CoversClass;
1616
use PHPUnit\Framework\TestCase;
1717
use Symfony\AI\Platform\Vector\Vector;
18+
use Symfony\AI\Store\Bridge\Postgres\Distance;
1819
use Symfony\AI\Store\Bridge\Postgres\Store;
1920
use Symfony\AI\Store\Document\Metadata;
2021
use Symfony\AI\Store\Document\VectorDocument;
@@ -33,7 +34,7 @@ private function normalizeQuery(string $query): string
3334
return trim($normalized);
3435
}
3536

36-
public function testAddSingleDocument(): void
37+
public function testAddSingleDocument()
3738
{
3839
$pdo = $this->createMock(\PDO::class);
3940
$statement = $this->createMock(\PDOStatement::class);
@@ -65,7 +66,7 @@ public function testAddSingleDocument(): void
6566
$store->add($document);
6667
}
6768

68-
public function testAddMultipleDocuments(): void
69+
public function testAddMultipleDocuments()
6970
{
7071
$pdo = $this->createMock(\PDO::class);
7172
$statement = $this->createMock(\PDOStatement::class);
@@ -105,7 +106,7 @@ public function testAddMultipleDocuments(): void
105106
$store->add($document1, $document2);
106107
}
107108

108-
public function testQueryWithoutMinScore(): void
109+
public function testQueryWithoutMinScore()
109110
{
110111
$pdo = $this->createMock(\PDO::class);
111112
$statement = $this->createMock(\PDOStatement::class);
@@ -152,7 +153,54 @@ public function testQueryWithoutMinScore(): void
152153
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
153154
}
154155

155-
public function testQueryWithMinScore(): void
156+
public function testQueryChangedDistanceMethodWithoutMinScore()
157+
{
158+
$pdo = $this->createMock(\PDO::class);
159+
$statement = $this->createMock(\PDOStatement::class);
160+
161+
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
162+
163+
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
164+
FROM embeddings_table
165+
166+
ORDER BY score ASC
167+
LIMIT 5';
168+
169+
$pdo->expects($this->once())
170+
->method('prepare')
171+
->with($this->callback(function ($sql) use ($expectedSql) {
172+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
173+
}))
174+
->willReturn($statement);
175+
176+
$uuid = Uuid::v4();
177+
178+
$statement->expects($this->once())
179+
->method('execute')
180+
->with(['embedding' => '[0.1,0.2,0.3]']);
181+
182+
$statement->expects($this->once())
183+
->method('fetchAll')
184+
->with(\PDO::FETCH_ASSOC)
185+
->willReturn([
186+
[
187+
'id' => $uuid->toRfc4122(),
188+
'embedding' => '[0.1,0.2,0.3]',
189+
'metadata' => json_encode(['title' => 'Test Document']),
190+
'score' => 0.95,
191+
],
192+
]);
193+
194+
$results = $store->query(new Vector([0.1, 0.2, 0.3]));
195+
196+
$this->assertCount(1, $results);
197+
$this->assertInstanceOf(VectorDocument::class, $results[0]);
198+
$this->assertEquals($uuid, $results[0]->id);
199+
$this->assertSame(0.95, $results[0]->score);
200+
$this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
201+
}
202+
203+
public function testQueryWithMinScore()
156204
{
157205
$pdo = $this->createMock(\PDO::class);
158206
$statement = $this->createMock(\PDOStatement::class);
@@ -189,7 +237,44 @@ public function testQueryWithMinScore(): void
189237
$this->assertCount(0, $results);
190238
}
191239

192-
public function testQueryWithCustomLimit(): void
240+
public function testQueryWithMinScoreAndDifferentDistance()
241+
{
242+
$pdo = $this->createMock(\PDO::class);
243+
$statement = $this->createMock(\PDOStatement::class);
244+
245+
$store = new Store($pdo, 'embeddings_table', 'embedding', Distance::Cosine);
246+
247+
$expectedSql = 'SELECT id, embedding AS embedding, metadata, (embedding <=> :embedding) AS score
248+
FROM embeddings_table
249+
WHERE (embedding <=> :embedding) >= :minScore
250+
ORDER BY score ASC
251+
LIMIT 5';
252+
253+
$pdo->expects($this->once())
254+
->method('prepare')
255+
->with($this->callback(function ($sql) use ($expectedSql) {
256+
return $this->normalizeQuery($sql) === $this->normalizeQuery($expectedSql);
257+
}))
258+
->willReturn($statement);
259+
260+
$statement->expects($this->once())
261+
->method('execute')
262+
->with([
263+
'embedding' => '[0.1,0.2,0.3]',
264+
'minScore' => 0.8,
265+
]);
266+
267+
$statement->expects($this->once())
268+
->method('fetchAll')
269+
->with(\PDO::FETCH_ASSOC)
270+
->willReturn([]);
271+
272+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), [], 0.8);
273+
274+
$this->assertCount(0, $results);
275+
}
276+
277+
public function testQueryWithCustomLimit()
193278
{
194279
$pdo = $this->createMock(\PDO::class);
195280
$statement = $this->createMock(\PDOStatement::class);
@@ -223,7 +308,7 @@ public function testQueryWithCustomLimit(): void
223308
$this->assertCount(0, $results);
224309
}
225310

226-
public function testQueryWithCustomVectorFieldName(): void
311+
public function testQueryWithCustomVectorFieldName()
227312
{
228313
$pdo = $this->createMock(\PDO::class);
229314
$statement = $this->createMock(\PDOStatement::class);
@@ -255,7 +340,7 @@ public function testQueryWithCustomVectorFieldName(): void
255340
$this->assertCount(0, $results);
256341
}
257342

258-
public function testInitialize(): void
343+
public function testInitialize()
259344
{
260345
$pdo = $this->createMock(\PDO::class);
261346

@@ -283,7 +368,7 @@ public function testInitialize(): void
283368
$store->initialize();
284369
}
285370

286-
public function testInitializeWithCustomVectorSize(): void
371+
public function testInitializeWithCustomVectorSize()
287372
{
288373
$pdo = $this->createMock(\PDO::class);
289374

@@ -306,7 +391,7 @@ public function testInitializeWithCustomVectorSize(): void
306391
$store->initialize(['vector_size' => 768]);
307392
}
308393

309-
public function testFromPdo(): void
394+
public function testFromPdo()
310395
{
311396
$pdo = $this->createMock(\PDO::class);
312397

@@ -315,7 +400,7 @@ public function testFromPdo(): void
315400
$this->assertInstanceOf(Store::class, $store);
316401
}
317402

318-
public function testFromDbalWithPdoDriver(): void
403+
public function testFromDbalWithPdoDriver()
319404
{
320405
$pdo = $this->createMock(\PDO::class);
321406
$connection = $this->createMock(Connection::class);
@@ -329,7 +414,7 @@ public function testFromDbalWithPdoDriver(): void
329414
$this->assertInstanceOf(Store::class, $store);
330415
}
331416

332-
public function testFromDbalWithNonPdoDriverThrowsException(): void
417+
public function testFromDbalWithNonPdoDriverThrowsException()
333418
{
334419
$connection = $this->createMock(Connection::class);
335420

@@ -343,7 +428,7 @@ public function testFromDbalWithNonPdoDriverThrowsException(): void
343428
Store::fromDbal($connection, 'test_table');
344429
}
345430

346-
public function testQueryWithNullMetadata(): void
431+
public function testQueryWithNullMetadata()
347432
{
348433
$pdo = $this->createMock(\PDO::class);
349434
$statement = $this->createMock(\PDOStatement::class);

0 commit comments

Comments
 (0)