From 269a465fc43e4da1a77938cc66a80d1ac873cb1a Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 18 Sep 2023 18:15:41 +0800 Subject: [PATCH] Feat/improve vector database logic (#1193) Co-authored-by: jyong --- api/commands.py | 141 +++++++++++++++--- api/core/index/base.py | 8 + .../keyword_table_index.py | 32 ++++ api/core/index/vector_index/base.py | 58 ++++++- .../index/vector_index/milvus_vector_index.py | 13 ++ api/core/index/vector_index/qdrant.py | 40 ++++- .../index/vector_index/qdrant_vector_index.py | 78 +++++++--- .../vector_index/weaviate_vector_index.py | 14 ++ api/core/tool/dataset_retriever_tool.py | 6 +- api/core/vector_store/qdrant_vector_store.py | 5 + ...fb077b04_add_dataset_collection_binding.py | 47 ++++++ api/models/dataset.py | 18 +++ api/services/dataset_service.py | 44 +++++- api/services/hit_testing_service.py | 5 +- 14 files changed, 463 insertions(+), 46 deletions(-) create mode 100644 api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py diff --git a/api/commands.py b/api/commands.py index 98801153412cad..8b6f5455f1a300 100644 --- a/api/commands.py +++ b/api/commands.py @@ -4,6 +4,7 @@ import random import string import time +import uuid import click from tqdm import tqdm @@ -23,7 +24,7 @@ from extensions.ext_database import db from libs.rsa import generate_key_pair from models.account import InvitationCode, Tenant, TenantAccountJoin -from models.dataset import Dataset, DatasetQuery, Document +from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding from models.model import Account, AppModelConfig, App import secrets import base64 @@ -239,7 +240,13 @@ def clean_unused_dataset_indexes(): kw_index = IndexBuilder.get_index(dataset, 'economy') # delete from vector index if vector_index: - vector_index.delete() + if dataset.collection_binding_id: + vector_index.delete_by_group_id(dataset.id) + else: + if dataset.collection_binding_id: + vector_index.delete_by_group_id(dataset.id) + else: + vector_index.delete() kw_index.delete() # update document update_params = { @@ -346,7 +353,8 @@ def create_qdrant_indexes(): is_valid=True, ) model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider) + embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", + model_provider=model_provider) embeddings = CacheEmbedding(embedding_model) from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig @@ -364,7 +372,8 @@ def create_qdrant_indexes(): index.create_qdrant_dataset(dataset) index_struct = { "type": 'qdrant', - "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} + "vector_store": { + "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} } dataset.index_struct = json.dumps(index_struct) db.session.commit() @@ -373,7 +382,8 @@ def create_qdrant_indexes(): click.echo('passed.') except Exception as e: click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red')) + click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) continue click.echo(click.style('Congratulations! 
Create {} dataset indexes.'.format(create_count), fg='green')) @@ -414,7 +424,8 @@ def update_qdrant_indexes(): is_valid=True, ) model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider) + embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", + model_provider=model_provider) embeddings = CacheEmbedding(embedding_model) from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig @@ -435,11 +446,104 @@ def update_qdrant_indexes(): click.echo('passed.') except Exception as e: click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red')) + click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) continue click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green')) + +@click.command('normalization-collections', help='restore all collections in one') +def normalization_collections(): + click.echo(click.style('Start normalization collections.', fg='green')) + normalization_count = 0 + + page = 1 + while True: + try: + datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ + .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + except NotFound: + break + + page += 1 + for dataset in datasets: + if not dataset.collection_binding_id: + try: + click.echo('restore dataset index: {}'.format(dataset.id)) + try: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except Exception: + provider = Provider( + id='provider_id', + tenant_id=dataset.tenant_id, + provider_name='openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'openai_api_key': 'TEST'}), + is_valid=True, + ) + model_provider = OpenAIProvider(provider=provider) + embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", + model_provider=model_provider) + embeddings = CacheEmbedding(embedding_model) + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name, + DatasetCollectionBinding.model_name == embedding_model.name). \ + order_by(DatasetCollectionBinding.created_at). 
\ + first() + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=embedding_model.model_provider.provider_name, + model_name=embedding_model.name, + collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' + ) + db.session.add(dataset_collection_binding) + db.session.commit() + + from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + + index = QdrantVectorIndex( + dataset=dataset, + config=QdrantConfig( + endpoint=current_app.config.get('QDRANT_URL'), + api_key=current_app.config.get('QDRANT_API_KEY'), + root_path=current_app.root_path + ), + embeddings=embeddings + ) + if index: + index.restore_dataset_in_one(dataset, dataset_collection_binding) + else: + click.echo('passed.') + + original_index = QdrantVectorIndex( + dataset=dataset, + config=QdrantConfig( + endpoint=current_app.config.get('QDRANT_URL'), + api_key=current_app.config.get('QDRANT_API_KEY'), + root_path=current_app.root_path + ), + embeddings=embeddings + ) + if original_index: + original_index.delete_original_collection(dataset, dataset_collection_binding) + normalization_count += 1 + else: + click.echo('passed.') + except Exception as e: + click.echo( + click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) + continue + + click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green')) + + @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.') @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.") def update_app_model_configs(batch_size): @@ -473,7 +577,7 @@ def update_app_model_configs(batch_size): .join(App, App.app_model_config_id == AppModelConfig.id) \ .filter(App.mode == 'completion') \ .count() - + if total_records == 0: click.secho("No data to migrate.", fg='green') return @@ -485,14 +589,14 @@ def update_app_model_configs(batch_size): offset = i * batch_size limit = min(batch_size, total_records - offset) - click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green') - + click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green') + data_batch = db.session.query(AppModelConfig) \ .join(App, App.app_model_config_id == AppModelConfig.id) \ .filter(App.mode == 'completion') \ .order_by(App.created_at) \ .offset(offset).limit(limit).all() - + if not data_batch: click.secho("No more data to migrate.", fg='green') break @@ -512,7 +616,7 @@ def update_app_model_configs(batch_size): app_data = db.session.query(App) \ .filter(App.id == data.app_id) \ .one() - + account_data = db.session.query(Account) \ .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \ .filter(TenantAccountJoin.role == 'owner') \ @@ -534,13 +638,15 @@ def update_app_model_configs(batch_size): db.session.commit() except Exception as e: - click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red') + click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", + fg='red') continue - - click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green') - + + click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green') + pbar.update(len(data_batch)) + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) @@ -551,4 +657,5 @@ def 
register_commands(app): app.cli.add_command(clean_unused_dataset_indexes) app.cli.add_command(create_qdrant_indexes) app.cli.add_command(update_qdrant_indexes) - app.cli.add_command(update_app_model_configs) \ No newline at end of file + app.cli.add_command(update_app_model_configs) + app.cli.add_command(normalization_collections) diff --git a/api/core/index/base.py b/api/core/index/base.py index a8755f9182bd4f..025c94bbe56baf 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -16,6 +16,10 @@ def __init__(self, dataset: Dataset): def create(self, texts: list[Document], **kwargs) -> BaseIndex: raise NotImplementedError + @abstractmethod + def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + raise NotImplementedError + @abstractmethod def add_texts(self, texts: list[Document], **kwargs): raise NotImplementedError @@ -28,6 +32,10 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + @abstractmethod + def delete_by_group_id(self, group_id: str) -> None: + raise NotImplementedError + @abstractmethod def delete_by_document_id(self, document_id: str): raise NotImplementedError diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 7b00e9825f50af..38f68c7a5057bc 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -46,6 +46,32 @@ def create(self, texts: list[Document], **kwargs) -> BaseIndex: return self + def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + keyword_table_handler = JiebaKeywordTableHandler() + keyword_table = {} + for text in texts: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self.dataset.id, + keyword_table=json.dumps({ + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": {} + } + }, cls=SetEncoder) + ) + db.session.add(dataset_keyword_table) + db.session.commit() + + self._save_dataset_keyword_table(keyword_table) + + return self + def add_texts(self, texts: list[Document], **kwargs): keyword_table_handler = JiebaKeywordTableHandler() @@ -120,6 +146,12 @@ def delete(self) -> None: db.session.delete(dataset_keyword_table) db.session.commit() + def delete_by_group_id(self, group_id: str) -> None: + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + db.session.delete(dataset_keyword_table) + db.session.commit() + def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { '__type__': 'keyword_table', diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index 98b7ea6b6d41cc..1e59135f37ed13 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -10,7 +10,7 @@ from core.index.base import BaseIndex from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding from models.dataset import Document as DatasetDocument @@ -110,6 
+110,12 @@ def delete_by_ids(self, ids: list[str]) -> None: for node_id in ids: vector_store.del_text(node_id) + def delete_by_group_id(self, group_id: str) -> None: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + vector_store.delete() + def delete(self) -> None: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) @@ -243,3 +249,53 @@ def update_qdrant_dataset(self, dataset: Dataset): raise e logging.info(f"Dataset {dataset.id} recreate successfully.") + + def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): + logging.info(f"restore dataset in_one,_dataset {dataset.id}") + + dataset_documents = db.session.query(DatasetDocument).filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == 'completed', + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).all() + + documents = [] + for dataset_document in dataset_documents: + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True + ).all() + + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + documents.append(document) + + if documents: + try: + self.create_with_collection_name(documents, dataset_collection_binding.collection_name) + except Exception as e: + raise e + + logging.info(f"Dataset {dataset.id} recreate successfully.") + + def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): + logging.info(f"delete original collection: {dataset.id}") + + self.delete() + + dataset.collection_binding_id = dataset_collection_binding.id + db.session.add(dataset) + db.session.commit() + + logging.info(f"Dataset {dataset.id} recreate successfully.") diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index a94d46ddd36f6f..abf57f55297c1a 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -69,6 +69,19 @@ def create(self, texts: list[Document], **kwargs) -> BaseIndex: return self + def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + uuids = self._get_uuids(texts) + self._vector_store = WeaviateVectorStore.from_documents( + texts, + self._embeddings, + client=self._client, + index_name=collection_name, + uuids=uuids, + by_text=False + ) + + return self + def _get_vector_store(self) -> VectorStore: """Only for created index.""" if self._vector_store: diff --git a/api/core/index/vector_index/qdrant.py b/api/core/index/vector_index/qdrant.py index c14a77f076474d..56f7a2ce340f61 100644 --- a/api/core/index/vector_index/qdrant.py +++ b/api/core/index/vector_index/qdrant.py @@ -28,6 +28,7 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance +from qdrant_client.http.models import PayloadSchemaType if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -84,6 +85,7 @@ class Qdrant(VectorStore): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" + GROUP_KEY 
= "group_id" VECTOR_NAME = None def __init__( @@ -93,9 +95,12 @@ def __init__( embeddings: Optional[Embeddings] = None, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, + group_payload_key: str = GROUP_KEY, + group_id: str = None, distance_strategy: str = "COSINE", vector_name: Optional[str] = VECTOR_NAME, embedding_function: Optional[Callable] = None, # deprecated + is_new_collection: bool = False ): """Initialize with necessary components.""" try: @@ -129,7 +134,10 @@ def __init__( self.collection_name = collection_name self.content_payload_key = content_payload_key or self.CONTENT_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY + self.group_payload_key = group_payload_key or self.GROUP_KEY self.vector_name = vector_name or self.VECTOR_NAME + self.group_id = group_id + self.is_new_collection= is_new_collection if embedding_function is not None: warnings.warn( @@ -170,6 +178,8 @@ def add_texts( batch_size: How many vectors upload per-request. Default: 64 + group_id: + collection group Returns: List of ids from adding the texts into the vectorstore. @@ -182,7 +192,11 @@ def add_texts( collection_name=self.collection_name, points=points, **kwargs ) added_ids.extend(batch_ids) - + # if is new collection, create payload index on group_id + if self.is_new_collection: + self.client.create_payload_index(self.collection_name, self.group_payload_key, + field_schema=PayloadSchemaType.KEYWORD, + field_type=PayloadSchemaType.KEYWORD) return added_ids @sync_call_fallback @@ -970,6 +984,8 @@ def from_texts( distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, + group_payload_key: str = GROUP_KEY, + group_id: str = None, vector_name: Optional[str] = VECTOR_NAME, batch_size: int = 64, shard_number: Optional[int] = None, @@ -1034,6 +1050,11 @@ def from_texts( metadata_payload_key: A payload key used to store the metadata of the document. Default: "metadata" + group_payload_key: + A payload key used to store the content of the document. + Default: "group_id" + group_id: + collection group id vector_name: Name of the vector to be used internally in Qdrant. 
Default: None @@ -1107,6 +1128,8 @@ def from_texts( distance_func, content_payload_key, metadata_payload_key, + group_payload_key, + group_id, vector_name, shard_number, replication_factor, @@ -1321,6 +1344,8 @@ def _construct_instance( distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, + group_payload_key: str = GROUP_KEY, + group_id: str = None, vector_name: Optional[str] = VECTOR_NAME, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -1350,6 +1375,7 @@ def _construct_instance( vector_size = len(partial_embeddings[0]) collection_name = collection_name or uuid.uuid4().hex distance_func = distance_func.upper() + is_new_collection = False client = qdrant_client.QdrantClient( location=location, url=url, @@ -1454,6 +1480,7 @@ def _construct_instance( init_from=init_from, timeout=timeout, # type: ignore[arg-type] ) + is_new_collection = True qdrant = cls( client=client, collection_name=collection_name, @@ -1462,6 +1489,9 @@ def _construct_instance( metadata_payload_key=metadata_payload_key, distance_strategy=distance_func, vector_name=vector_name, + group_id=group_id, + group_payload_key=group_payload_key, + is_new_collection=is_new_collection ) return qdrant @@ -1516,6 +1546,8 @@ def _build_payloads( metadatas: Optional[List[dict]], content_payload_key: str, metadata_payload_key: str, + group_id: str, + group_payload_key: str ) -> List[dict]: payloads = [] for i, text in enumerate(texts): @@ -1529,6 +1561,7 @@ def _build_payloads( { content_payload_key: text, metadata_payload_key: metadata, + group_payload_key: group_id } ) @@ -1578,7 +1611,7 @@ def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: else: out.append( rest.FieldCondition( - key=f"{self.metadata_payload_key}.{key}", + key=key, match=rest.MatchValue(value=value), ) ) @@ -1654,6 +1687,7 @@ def _generate_rest_batches( metadatas: Optional[List[dict]] = None, ids: Optional[Sequence[str]] = None, batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest @@ -1684,6 +1718,8 @@ def _generate_rest_batches( batch_metadatas, self.content_payload_key, self.metadata_payload_key, + self.group_id, + self.group_payload_key ), ) ] diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index 4814837c8f0fb3..2be77609a654b5 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -6,18 +6,20 @@ from langchain.schema import Document, BaseRetriever from langchain.vectorstores import VectorStore from pydantic import BaseModel +from qdrant_client.http.models import HnswConfigDiff from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.qdrant_vector_store import QdrantVectorStore -from models.dataset import Dataset +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding class QdrantConfig(BaseModel): endpoint: str api_key: Optional[str] root_path: Optional[str] - + def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith('path:'): path = self.endpoint.replace('path:', '') @@ -43,16 +45,21 @@ def get_type(self) -> str: return 'qdrant' def get_index_name(self, dataset: Dataset) -> str: - if self.dataset.index_struct_dict: - class_prefix: str = 
self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - class_prefix += '_Node' - - return class_prefix + if dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ + one_or_none() + if dataset_collection_binding: + return dataset_collection_binding.collection_name + else: + raise ValueError('Dataset Collection Bindings is not exist!') + else: + if self.dataset.index_struct_dict: + class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] + return class_prefix - dataset_id = dataset.id - return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + dataset_id = dataset.id + return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' def to_index_struct(self) -> dict: return { @@ -68,6 +75,27 @@ def create(self, texts: list[Document], **kwargs) -> BaseIndex: collection_name=self.get_index_name(self.dataset), ids=uuids, content_payload_key='page_content', + group_id=self.dataset.id, + group_payload_key='group_id', + hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, + max_indexing_threads=0, on_disk=False), + **self._client_config.to_qdrant_params() + ) + + return self + + def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + uuids = self._get_uuids(texts) + self._vector_store = QdrantVectorStore.from_documents( + texts, + self._embeddings, + collection_name=collection_name, + ids=uuids, + content_payload_key='page_content', + group_id=self.dataset.id, + group_payload_key='group_id', + hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, + max_indexing_threads=0, on_disk=False), **self._client_config.to_qdrant_params() ) @@ -78,8 +106,6 @@ def _get_vector_store(self) -> VectorStore: if self._vector_store: return self._vector_store attributes = ['doc_id', 'dataset_id', 'document_id'] - if self._is_origin(): - attributes = ['doc_id'] client = qdrant_client.QdrantClient( **self._client_config.to_qdrant_params() ) @@ -88,16 +114,15 @@ def _get_vector_store(self) -> VectorStore: client=client, collection_name=self.get_index_name(self.dataset), embeddings=self._embeddings, - content_payload_key='page_content' + content_payload_key='page_content', + group_id=self.dataset.id, + group_payload_key='group_id' ) def _get_vector_store_class(self) -> type: return QdrantVectorStore def delete_by_document_id(self, document_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) @@ -114,9 +139,6 @@ def delete_by_document_id(self, document_id: str): )) def delete_by_ids(self, ids: list[str]) -> None: - if self._is_origin(): - self.recreate_dataset(self.dataset) - return vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) @@ -132,6 +154,22 @@ def delete_by_ids(self, ids: list[str]) -> None: ], )) + def delete_by_group_id(self, group_id: str) -> None: + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + from qdrant_client.http import models + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=group_id), + ), + ], + )) + + def 
_is_origin(self): if self.dataset.index_struct_dict: class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index df12dd9c534464..1432a707079e3f 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -91,6 +91,20 @@ def create(self, texts: list[Document], **kwargs) -> BaseIndex: return self + def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: + uuids = self._get_uuids(texts) + self._vector_store = WeaviateVectorStore.from_documents( + texts, + self._embeddings, + client=self._client, + index_name=self.get_index_name(self.dataset), + uuids=uuids, + by_text=False + ) + + return self + + def _get_vector_store(self) -> VectorStore: """Only for created index.""" if self._vector_store: diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 4c9c9b625d5873..d90636e71744bd 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool): return_resource: str retriever_from: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description @@ -94,7 +93,10 @@ def _run(self, query: str) -> str: query, search_type='similarity_score_threshold', search_kwargs={ - 'k': self.k + 'k': self.k, + 'filter': { + 'group_id': [dataset.id] + } } ) else: diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py index 01c49918689d58..dc92b8cb249563 100644 --- a/api/core/vector_store/qdrant_vector_store.py +++ b/api/core/vector_store/qdrant_vector_store.py @@ -46,6 +46,11 @@ def delete(self): self.client.delete_collection(collection_name=self.collection_name) + def delete_group(self): + self._reload_if_needed() + + self.client.delete_collection(collection_name=self.collection_name) + @classmethod def _document_from_scored_point( cls, diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py new file mode 100644 index 00000000000000..d37570e1bd4b5c --- /dev/null +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -0,0 +1,47 @@ +"""add_dataset_collection_binding + +Revision ID: 6e2cfb077b04 +Revises: 77e83833755c +Create Date: 2023-09-13 22:16:48.027810 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '6e2cfb077b04' +down_revision = '77e83833755c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_collection_bindings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('collection_name', sa.String(length=64), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey') + ) + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('collection_binding_id') + + with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: + batch_op.drop_index('provider_model_name_idx') + + op.drop_table('dataset_collection_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 338eb173cf32a7..a9a33cc1a7b470 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -38,6 +38,8 @@ class Dataset(db.Model): server_default=db.text('CURRENT_TIMESTAMP(0)')) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) + collection_binding_id = db.Column(UUID, nullable=True) + @property def dataset_keyword_table(self): @@ -445,3 +447,19 @@ def set_embedding(self, embedding_data: list[float]): def get_embedding(self) -> list[float]: return pickle.loads(self.embedding) + + +class DatasetCollectionBinding(db.Model): + __tablename__ = 'dataset_collection_bindings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), + db.Index('provider_model_name_idx', 'provider_name', 'model_name') + + ) + + id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + provider_name = db.Column(db.String(40), nullable=False) + model_name = db.Column(db.String(40), nullable=False) + collection_name = db.Column(db.String(64), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index a40476a0e48a12..78686112a58137 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -20,7 +20,8 @@ from extensions.ext_database import db from libs import helper from models.account import Account -from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment +from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \ + DatasetCollectionBinding from models.model import UploadFile from models.source import DataSourceBinding from services.errors.account import NoPermissionError @@ -147,6 +148,7 @@ def update_dataset(dataset_id, data, user): action = 'remove' filtered_data['embedding_model'] = None filtered_data['embedding_model_provider'] = None + filtered_data['collection_binding_id'] = None elif data['indexing_technique'] 
== 'high_quality': action = 'add' # get embedding model setting @@ -156,6 +158,11 @@ def update_dataset(dataset_id, data, user): ) filtered_data['embedding_model'] = embedding_model.name filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.model_provider.provider_name, + embedding_model.name + ) + filtered_data['collection_binding_id'] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( f"No Embedding Model available. Please configure a valid provider " @@ -464,7 +471,11 @@ def save_document_with_dataset_id(dataset: Dataset, document_data: dict, ) dataset.embedding_model = embedding_model.name dataset.embedding_model_provider = embedding_model.model_provider.provider_name - + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.model_provider.provider_name, + embedding_model.name + ) + dataset.collection_binding_id = dataset_collection_binding.id documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) @@ -720,10 +731,16 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun if total_count > tenant_document_count: raise ValueError(f"All your documents have overed limit {tenant_document_count}.") embedding_model = None + dataset_collection_binding_id = None if document_data['indexing_technique'] == 'high_quality': embedding_model = ModelFactory.get_embedding_model( tenant_id=tenant_id ) + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.model_provider.provider_name, + embedding_model.name + ) + dataset_collection_binding_id = dataset_collection_binding.id # save dataset dataset = Dataset( tenant_id=tenant_id, @@ -732,7 +749,8 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun indexing_technique=document_data["indexing_technique"], created_by=account.id, embedding_model=embedding_model.name if embedding_model else None, - embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None + embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, + collection_binding_id=dataset_collection_binding_id ) db.session.add(dataset) @@ -1069,3 +1087,23 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) db.session.delete(segment) db.session.commit() + + +class DatasetCollectionBindingService: + @classmethod + def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name). \ + order_by(DatasetCollectionBinding.created_at). 
\ + first() + + if not dataset_collection_binding: + dataset_collection_binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' + ) + db.session.add(dataset_collection_binding) + db.session.flush() + return dataset_collection_binding diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index b6e622a90b9636..063292969cd55a 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -47,7 +47,10 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 1 query, search_type='similarity_score_threshold', search_kwargs={ - 'k': 10 + 'k': 10, + 'filter': { + 'group_id': [dataset.id] + } } ) end = time.perf_counter()
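
The pattern this patch introduces: instead of one Qdrant collection per dataset, datasets that share the same embedding provider and model are bound (via the new dataset_collection_bindings table) to a single shared collection. Every point is written with a "group_id" payload equal to its dataset id, a keyword payload index is created on that field for new collections, retrieval adds a group_id filter, and per-dataset deletion becomes a filtered point delete rather than dropping the whole collection.

Below is a minimal stand-alone sketch of that pattern using qdrant-client directly rather than the project's wrapper classes; the URL, collection name, dataset id, and vector size are illustrative assumptions, not values taken from the patch:

    from qdrant_client import QdrantClient
    from qdrant_client.http import models as rest

    client = QdrantClient(url="http://localhost:6333")   # assumed local Qdrant instance
    collection = "Vector_index_shared_Node"              # hypothetical shared collection name

    # Create the shared collection once; m=0 with payload_m=16 mirrors the
    # HnswConfigDiff the patch passes, so HNSW links are built per payload group.
    client.recreate_collection(
        collection_name=collection,
        vectors_config=rest.VectorParams(size=1536, distance=rest.Distance.COSINE),
        hnsw_config=rest.HnswConfigDiff(m=0, payload_m=16),
    )

    # Keyword payload index on group_id, as add_texts does for newly created collections.
    client.create_payload_index(
        collection_name=collection,
        field_name="group_id",
        field_schema=rest.PayloadSchemaType.KEYWORD,
    )

    dataset_id = "00000000-0000-0000-0000-000000000000"  # hypothetical dataset id

    # Each point carries its dataset id under the group_id payload key.
    client.upsert(
        collection_name=collection,
        points=[rest.PointStruct(
            id=1,
            vector=[0.0] * 1536,
            payload={"page_content": "example chunk", "group_id": dataset_id},
        )],
    )

    group_filter = rest.Filter(must=[
        rest.FieldCondition(key="group_id", match=rest.MatchValue(value=dataset_id)),
    ])

    # Retrieval scoped to one dataset inside the shared collection.
    hits = client.search(
        collection_name=collection,
        query_vector=[0.0] * 1536,
        query_filter=group_filter,
        limit=3,
    )

    # Equivalent of delete_by_group_id: remove only this dataset's points,
    # leaving the shared collection and other datasets' points intact.
    client.delete(
        collection_name=collection,
        points_selector=rest.FilterSelector(filter=group_filter),
    )

The same filter shape is what dataset_retriever_tool.py and hit_testing_service.py pass as search_kwargs ({'filter': {'group_id': [dataset.id]}}), and what delete_by_group_id builds in qdrant_vector_index.py.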