Commit
feat: support relyt vector database (#3367)
Co-authored-by: jingsi <jingsi@leadincloud.com>
1 parent 92f8c40 · commit 3339783
Showing 8 changed files with 225 additions and 2 deletions.
Empty file.
@@ -0,0 +1,169 @@
import logging
from typing import Any

from pgvecto_rs.sdk import PGVectoRs, Record
from pydantic import BaseModel, root_validator
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)

class RelytConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config RELYT_HOST is required")
        if not values['port']:
            raise ValueError("config RELYT_PORT is required")
        if not values['user']:
            raise ValueError("config RELYT_USER is required")
        if not values['password']:
            raise ValueError("config RELYT_PASSWORD is required")
        if not values['database']:
            raise ValueError("config RELYT_DATABASE is required")
        return values

class RelytVector(BaseVector):

    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
        super().__init__(collection_name)
        self._client_config = config
        self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
        self._client = PGVectoRs(
            db_url=self._url,
            collection_name=self._collection_name,
            dimension=dim
        )
        self._fields = []

    def get_type(self) -> str:
        return 'relyt'

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        index_params = {}
        metadatas = [d.metadata for d in texts]
        self.create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def create_collection(self, dimension: int):
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            # Skip creation if another worker has already initialized this collection recently.
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            index_name = f"{self._collection_name}_embedding_index"
            with Session(self._client._engine) as session:
                drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
                        id UUID PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSONB NOT NULL,
                        embedding vector({dimension}) NOT NULL
                    ) using heap;
                """)
                session.execute(create_statement)
                # HNSW index on the embedding column via the pgvecto.rs "vectors" access method,
                # using L2 distance.
                index_statement = sql_text(f"""
                    CREATE INDEX {index_name}
                    ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
                    WITH (options = $$
                    optimizing.optimizing_threads = 30
                    segment.max_growing_segment_size = 2000
                    segment.max_sealed_segment_size = 30000000
                    [indexing.hnsw]
                    m=30
                    ef_construction=500
                    $$);
                """)
                session.execute(index_statement)
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
        pks = [str(r.id) for r in records]
        self._client.insert(records)
        return pks

    def delete_by_document_id(self, document_id: str):
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._client.delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        result = None
        with Session(self._client._engine) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
            )
            result = session.execute(select_statement).fetchall()
        if result:
            return [item[0] for item in result]
        else:
            return None

    def delete_by_metadata_field(self, key: str, value: str):
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._client.delete_by_ids(ids)

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        with Session(self._client._engine) as session:
            # Quote each id so the IN (...) clause is valid SQL.
            quoted_doc_ids = ", ".join(f"'{doc_id}'" for doc_id in doc_ids)
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ({quoted_doc_ids}); "
            )
            result = session.execute(select_statement).fetchall()
        if result:
            ids = [item[0] for item in result]
            self._client.delete_by_ids(ids)

    def delete(self) -> None:
        with Session(self._client._engine) as session:
            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
            session.commit()

    def text_exists(self, id: str) -> bool:
        with Session(self._client._engine) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
            )
            result = session.execute(select_statement).fetchall()
        return len(result) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from pgvecto_rs.sdk import filters
        filter_condition = filters.meta_contains(kwargs.get('filter'))
        results = self._client.search(
            top_k=int(kwargs.get('top_k')),
            embedding=query_vector,
            filter=filter_condition
        )

        # Organize results.
        docs = []
        for record, dis in results:
            metadata = record.meta
            metadata['score'] = dis
            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
            if dis > score_threshold:
                doc = Document(page_content=record.text,
                               metadata=metadata)
                docs.append(doc)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz/relyt doesn't support bm25 search
        return []
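
For context, here is a minimal usage sketch of the new class (not part of the commit). It assumes the module is importable from a relyt package under core/rag/datasource/vdb (the exact file path is not shown in this view), that a Relyt or PostgreSQL instance with the pgvecto.rs extension is reachable using the placeholder credentials below, and that the redis_client extension is initialized, since create_collection takes a Redis lock. The collection name, credentials, and sample document are illustrative.

# Hypothetical import path; the commit adds the class under core/rag/datasource/vdb.
from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
from core.rag.models.document import Document

# Placeholder connection values; real deployments would supply RELYT_HOST, RELYT_PORT, etc.
config = RelytConfig(
    host="127.0.0.1",
    port=5432,
    user="postgres",
    password="postgres",
    database="dify",
)

docs = [Document(page_content="hello relyt",
                 metadata={"doc_id": "doc-1", "document_id": "document-1"})]
embeddings = [[0.1, 0.2, 0.3, 0.4]]  # dimension must match the dim passed to RelytVector

vector = RelytVector(collection_name="example", config=config, dim=len(embeddings[0]))
vector.create(docs, embeddings)  # creates the table and HNSW index, then inserts the records
hits = vector.search_by_vector(
    embeddings[0],
    top_k=2,
    score_threshold=0.0,
    filter={"document_id": "document-1"},  # forwarded to pgvecto_rs filters.meta_contains
)
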
@@ -79,3 +79,4 @@ azure-storage-blob==12.9.0
azure-identity==1.15.0
lxml==5.1.0
xlrd~=2.0.1
pgvecto-rs==0.1.4