Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
pulsejet committed Oct 13, 2024
1 parent b8561f8 commit e8ae9e6
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 57 deletions.
105 changes: 59 additions & 46 deletions lib/Command/AI.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

namespace OCA\Memories\Command;

use OC\DB\QueryBuilder\QueryBuilder;
use OCA\Memories\Db\FsManager;
use OCA\Memories\Db\SQL;
use OCA\Memories\Db\TimelineQuery;
use OCA\Memories\Db\TimelineRoot;
use OCA\Memories\Util;
Expand All @@ -43,6 +43,7 @@

const API_IMAGES = '/images';
const API_TEXT = '/text';
const VECTOR_SIZE = 768;

class AIOpts
{
Expand Down Expand Up @@ -85,91 +86,93 @@ public function search(string $prompt): array
$query = $this->connection->getQueryBuilder();

$classlist = array_map(static fn (array $class): int => $class['index'], $response['classes']);
// $classlist = array_slice($classlist, 0, 1);
// $classlist = \array_slice($classlist, 0, 8);

$classQuery = $this->connection->getQueryBuilder();
$classQuery->select('c.word')
->from('memories_ss_class', 'c')
->where($query->expr()->andX(
$query->expr()->eq('c.fileid', 'v.fileid'),
$query->expr()->orX(
...array_map(static fn ($idx) =>
$query->expr()->eq('c.class', $query->expr()->literal($idx)),
$classlist)
),
));
->where($classQuery->expr()->andX(
$classQuery->expr()->eq('c.fileid', 'v.fileid'),
$classQuery->expr()->in('c.class', array_map(static fn ($idx) => $classQuery->expr()->literal($idx), $classlist)),
))
;

$subquery = $this->connection->getQueryBuilder();
$subquery->select('v.fileid')
->from('memories_ss_vectors', 'v')
->where($subquery->createFunction("EXISTS ({$classQuery->getSql()})"))
->where(SQL::exists($query, $classQuery))
->groupBy('v.fileid')
;

// Take vector projection
$components = [];
foreach ($response['embedding'] as $i => $value) {
$value = number_format($value, 6);
$components[] = "v.v{$i}*{$value}";
$components[] = "(v.v{$i}*({$value}))";
}

// Divide the operators into chunks of 96 each
$sums = array_chunk($components, 96);
// Divide the operators into chunks of 48 each
$sums = array_chunk($components, 48);

// Add the sum of each chunk
for ($i = 0; $i < \count($sums); ++$i) {
$sum = implode('+', $sums[$i]);
$subquery->addSelect($subquery->createFunction("({$sum}) as score{$i}"));
$sum = $subquery->createFunction(implode('+', $sums[$i]));
$subquery->selectAlias($sum, "score{$i}");
}

// Create outer query
$query->select('sq.fileid')
->from($query->createFunction("({$subquery->getSQL()}) sq"))
->from(SQL::subquery($query, $subquery, 'sq'))
;

// Add all score sums together
$sum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums)));
$query->addSelect($query->createFunction("({$sum}) as score"));
$finalSum = implode('+', array_map(static fn ($_, $i) => "score{$i}", $sums, array_keys($sums)));
$finalSum = $query->createFunction("({$finalSum})");
$query->selectAlias($finalSum, 'score');

// Filter for scores less than 1
// $query->andWhere($query->createFunction("(({$sum}) > 0.04)"));
$query = SQL::materialize($query, 'fsq');
$query->andWhere($query->expr()->gt('fsq.score', $query->expr()->literal(0.04)));

$query->orderBy('score', 'DESC');
$query->orderBy('fsq.score', 'DESC');

// $query->setMaxResults(8); // batch size
header('Content-Type: text/html');

// SQL::debugQuery($query);

$t1 = microtime(true);

$res = $query->executeQuery()->fetchAll();

// print length and discard after 10
echo "<h1>Results: ".\count($res)."</h1>";
$res = array_slice($res, 0, 10);
echo '<h1>Results: '.\count($res).'</h1>';
$res = \array_slice($res, 0, 10);

$t2 = microtime(true);
echo "<h1>Search took ".(($t2 - $t1)*1000)." ms</h1>";
echo "class list: ".json_encode($response['classes'])."<br>";
echo '<h1>Search took '.(($t2 - $t1) * 1000).' ms</h1>';
echo 'class list: '.json_encode($response['classes']).'<br>';

foreach ($res as &$row) {
$fid = $row['fileid'] = (int) $row['fileid'];
$row['score'] = (float) $row['score'];

$row['score'] = pow(2, $row['score'] * 40);
$row['score'] = 2 ** ($row['score'] * 40);

$p = $this->preview->getPreview($this->fs->getUserFile($fid), 1024, 1024);
$data = $p->getContent();

//get classes for this file
// get classes for this file
$q = $this->connection->getQueryBuilder();
$w = $q->select('word')
->from('memories_ss_class', 'c')
->where($q->expr()->eq('c.fileid', $q->createNamedParameter($fid)))
->executeQuery()
->fetchAll(\PDO::FETCH_COLUMN);
->fetchAll(\PDO::FETCH_COLUMN)
;

echo "<h2>Score: ". $row['score'] . "</h2>";
echo "Row: ".json_encode($row)."<br>";
echo "Classes: ".json_encode($w)."</br>";
echo '<h2>Score: '.$row['score'].'</h2>';
echo 'Row: '.json_encode($row).'<br>';
echo 'Classes: '.json_encode($w).'</br>';
echo "<img src='data:image/jpeg;base64,".base64_encode($data)."'>";
}

Expand Down Expand Up @@ -227,14 +230,18 @@ private function indexUser(IUser $user): void
->from('memories', 'm')
;

$this->tq->joinFilecache($query, $root, true, false, true);
$query = $this->tq->filterFilecache($query, $root, true, false, true);

// Filter by the files that are not indexed by the AI
$query
->leftJoin('m', 'memories_ss_vectors', 'v', $query->expr()->eq('m.fileid', 'v.fileid'))
->where($query->expr()->isNull('v.fileid'))
->setMaxResults(16) // batch size
$vecSq = $this->connection->getQueryBuilder();
$vecSq->select($vecSq->expr()->literal(1))
->from('memories_ss_vectors', 'v')
->where($vecSq->expr()->eq('m.fileid', 'v.fileid'))
;
$query->andWhere(SQL::notExists($query, $vecSq));

// Batch size
$query->setMaxResults(16);

// FileIds inside this folder that need indexing
$objs = Util::transaction(fn () => $this->tq->executeQueryWithCTEs($query)->fetchAll());
Expand All @@ -256,21 +263,23 @@ private function indexSet(Folder $folder, array $objs): void
return;
}

$count = \count($objs);
$this->output->writeln("Indexing {$count} files");

// Get previews for all files
foreach ($objs as &$obj) {
$obj['fileid'] = (int) $obj['fileid'];
$fileid = $obj['fileid'] = (int) $obj['fileid'];
$obj['mtime'] = (int) $obj['mtime'];

try {
// Get file object
$file = $folder->getById($obj['fileid']);
if (empty($file)) {
$this->output->writeln("<error>File not found: {$fileid}</error>");

continue;
}
$file = $file[0];
if (!$file instanceof File) {
$this->output->writeln("<error>Not a file: {$fileid}</error>");

continue;
}

Expand All @@ -285,13 +294,17 @@ private function indexSet(Folder $folder, array $objs): void
$mime = $preview->getMimeType();
$data = base64_encode($content);
$obj['image'] = "data:{$mime};base64,{$data}";

// Log
$this->output->writeln("Indexing {$file->getPath()}");
} catch (\Exception $e) {
$obj['fileid'] = 0; // mark failure
$this->output->writeln("<error>Failed to get preview: {$e->getMessage()}</error>".PHP_EOL);
$this->output->writeln("<error>Failed to get preview: {$e->getMessage()}</error>");
}
}

// Filter out failed files
// TODO: store failure reason
$objs = array_filter($objs, static fn ($obj) => $obj['fileid'] > 0);

// Post to server
Expand Down Expand Up @@ -327,11 +340,11 @@ private function indexSet(Folder $folder, array $objs): void
private function ssStoreResult(array $result, int $fileid, int $mtime): void
{
// Check result
if (768 !== \count($result['embedding'])) {
if (VECTOR_SIZE !== \count($result['embedding'])) {
throw new \Exception('Invalid embedding size');
}

if (\count($result['classes']) === 0) {
if (0 === \count($result['classes'])) {
throw new \Exception('No classes returned.');
}

Expand All @@ -345,8 +358,8 @@ private function ssStoreResult(array $result, int $fileid, int $mtime): void
];

// Store embedding
for ($i = 0; $i < \count($result['embedding']); ++$i) {
$values['v'.$i] = $query->expr()->literal($result['embedding'][$i]);
for ($i = 0; $i < VECTOR_SIZE; ++$i) {
$values["v{$i}"] = $query->expr()->literal($result['embedding'][$i]);
}

$query->insert('memories_ss_vectors')
Expand Down
5 changes: 3 additions & 2 deletions lib/Db/SQL.php
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ public static function materialize(IQueryBuilder $query, string $alias): IQueryB
*
* @param IQueryBuilder $query The query to create the function on
* @param IQueryBuilder $subquery The subquery to use
* @param string $alias The alias to use for the subquery
*/
public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery): IQueryFunction
public static function subquery(IQueryBuilder &$query, IQueryBuilder &$subquery, string $alias = ''): IQueryFunction
{
return $query->createFunction("({$subquery->getSQL()})");
return $query->createFunction("({$subquery->getSQL()}) {$alias}");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
use OCP\Migration\IOutput;
use OCP\Migration\SimpleMigrationStep;

class Version800000Date20240327191449 extends SimpleMigrationStep
class Version900000Date20240327191449 extends SimpleMigrationStep
{
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
*/
public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {}
public function preSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void
{
// Patch doctrine to use float instead of double
\Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, RealFloatType::class);
}

/**
* @param \Closure(): ISchemaWrapper $schemaClosure
Expand All @@ -58,10 +62,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
'notnull' => true,
'length' => 20,
]);
$table->addColumn('lsh', Types::INTEGER, [
'notnull' => true,
'default' => 0,
]);

// Create embedding columns
$size = 768;
Expand All @@ -74,7 +74,6 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op

$table->setPrimaryKey(['id']);
$table->addIndex(['fileid', 'mtime'], 'memories_ss_vec_fileid');
$table->addIndex(['lsh'], 'memories_ss_vec_lsh');
}

return $schema;
Expand All @@ -83,5 +82,26 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
*/
public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void {}
public function postSchemaChange(IOutput $output, \Closure $schemaClosure, array $options): void
{
// Revert doctrine patch
\Doctrine\DBAL\Types\Type::overrideType(Types::FLOAT, \Doctrine\DBAL\Types\FloatType::class);
}
}

class RealFloatType extends \Doctrine\DBAL\Types\FloatType
{
public function getSQLDeclaration(array $column, \Doctrine\DBAL\Platforms\AbstractPlatform $platform)
{
if (preg_match('/mysql|mariadb/i', $platform::class)) {
return 'FLOAT';
}

// https://www.postgresql.org/docs/current/datatype-numeric.html
if (preg_match('/postgres/i', $platform::class)) {
return 'REAL';
}

return parent::getSQLDeclaration($column, $platform);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
use OCP\Migration\IOutput;
use OCP\Migration\SimpleMigrationStep;

class Version800000Date20240327192949 extends SimpleMigrationStep
class Version900000Date20240327192949 extends SimpleMigrationStep
{
/**
* @param \Closure(): ISchemaWrapper $schemaClosure
Expand Down Expand Up @@ -61,6 +61,10 @@ public function changeSchema(IOutput $output, \Closure $schemaClosure, array $op
'notnull' => true,
'default' => 0,
]);
$table->addColumn('word', Types::STRING, [
'notnull' => false,
'length' => 64,
]);

$table->setPrimaryKey(['id']);
$table->addIndex(['fileid'], 'memories_ss_cls_fileid');
Expand Down

0 comments on commit e8ae9e6

Please sign in to comment.