Feat/improve vector database logic (langgenius#1193)
Co-authored-by: jyong <jyong@dify.ai>
JohnJyong authored Sep 18, 2023
1 parent 60e0bbd commit 269a465
Showing 14 changed files with 463 additions and 46 deletions.
141 changes: 124 additions & 17 deletions api/commands.py
@@ -4,6 +4,7 @@
import random
import string
import time
import uuid

import click
from tqdm import tqdm
@@ -23,7 +24,7 @@
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
@@ -346,7 +353,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)

from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +372,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@@ -373,7 +382,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue

click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +424,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)

from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +446,104 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue

click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))


@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = 0

page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break

page += 1
for dataset in datasets:
if not dataset.collection_binding_id:
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()

if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()

from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')

original_index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if original_index:
original_index.delete_original_collection(dataset, dataset_collection_binding)
normalization_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue

click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))


@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
@@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()

if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
@@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
offset = i * batch_size
limit = min(batch_size, total_records - offset)

click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()

if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
@@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()

account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
db.session.commit()

except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')

click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')

pbar.update(len(data_batch))


def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -551,4 +657,5 @@ def register_commands(app):
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
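
For context, the commands registered above are plain Click commands on the Flask CLI. Below is a minimal sketch of driving the new normalization command from Python, assuming an application factory named create_app (that name and import path are assumptions, not part of this commit):

# Illustrative sketch only, not part of this commit.
# Assumes the API exposes an application factory `create_app`; adjust the import
# to the actual project layout.
from app import create_app

app = create_app()
runner = app.test_cli_runner()  # Flask's built-in Click test runner

# Equivalent to running `flask normalization-collections` from a shell.
result = runner.invoke(args=['normalization-collections'])
print(result.exit_code)
print(result.output)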
8 changes: 8 additions & 0 deletions api/core/index/base.py
@@ -16,6 +16,10 @@ def __init__(self, dataset: Dataset):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError

@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError

@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@@ -28,6 +32,10 @@ def text_exists(self, id: str) -> bool:
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError

@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError

@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
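
The two abstract methods added above mean every concrete index now has to provide group-scoped deletion and creation into a named collection. A self-contained sketch of that contract with stand-in classes (these are not the project's real types):

# Stand-in classes to illustrate the extended BaseIndex contract; not project code.
from abc import ABC, abstractmethod
from typing import Any


class Doc:  # simplified placeholder for langchain's Document
    def __init__(self, page_content: str, metadata: dict[str, Any]):
        self.page_content = page_content
        self.metadata = metadata


class IndexContract(ABC):
    @abstractmethod
    def create_with_collection_name(self, texts: list[Doc], collection_name: str, **kwargs) -> "IndexContract":
        raise NotImplementedError

    @abstractmethod
    def delete_by_group_id(self, group_id: str) -> None:
        raise NotImplementedError


class InMemoryIndex(IndexContract):
    def __init__(self):
        self._collections: dict[str, list[Doc]] = {}

    def create_with_collection_name(self, texts: list[Doc], collection_name: str, **kwargs) -> "IndexContract":
        # All texts from one dataset land in a shared, named collection.
        self._collections.setdefault(collection_name, []).extend(texts)
        return self

    def delete_by_group_id(self, group_id: str) -> None:
        # Remove only the documents whose metadata ties them to the given group (dataset) id.
        for name, docs in self._collections.items():
            self._collections[name] = [d for d in docs if d.metadata.get("group_id") != group_id]


index = InMemoryIndex()
index.create_with_collection_name([Doc("hello", {"group_id": "dataset-1"})], "Vector_index_shared_Node")
index.delete_by_group_id("dataset-1")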
32 changes: 32 additions & 0 deletions api/core/index/keyword_table_index/keyword_table_index.py
@@ -46,6 +46,32 @@ def create(self, texts: list[Document], **kwargs) -> BaseIndex:

return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()

self._save_dataset_keyword_table(keyword_table)

return self

def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()

@@ -120,6 +146,12 @@ def delete(self) -> None:
db.session.delete(dataset_keyword_table)
db.session.commit()

def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()

def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
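
As a side note, the keyword table built in create_with_collection_name is a mapping from extracted keywords to segment doc_ids, serialized with SetEncoder. A small, self-contained sketch of that shape (values are illustrative only):

# Illustrative sketch of the keyword-table shape used above; values are made up.
import json


class SetEncoder(json.JSONEncoder):  # stand-in for the project's SetEncoder
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)


def add_text_to_keyword_table(table: dict, doc_id: str, keywords: list[str]) -> dict:
    # Mirrors the role of _add_text_to_keyword_table: keyword -> set of doc_ids.
    for keyword in keywords:
        table.setdefault(keyword, set()).add(doc_id)
    return table


keyword_table: dict[str, set] = {}
add_text_to_keyword_table(keyword_table, "doc-1", ["vector", "database"])
add_text_to_keyword_table(keyword_table, "doc-2", ["database", "qdrant"])

print(json.dumps({
    "__type__": "keyword_table",
    "__data__": {"index_id": "dataset-1", "summary": None, "table": keyword_table},
}, cls=SetEncoder))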
58 changes: 57 additions & 1 deletion api/core/index/vector_index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument


@@ -110,6 +110,12 @@ def delete_by_ids(self, ids: list[str]) -> None:
for node_id in ids:
vector_store.del_text(node_id)

def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)

vector_store.delete()

def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ def update_qdrant_dataset(self, dataset: Dataset):
raise e

logging.info(f"Dataset {dataset.id} recreate successfully.")

def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")

dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()

documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()

for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)

documents.append(document)

if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
except Exception as e:
raise e

logging.info(f"Dataset {dataset.id} recreate successfully.")

def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")

self.delete()

dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()

logging.info(f"Dataset {dataset.id} recreate successfully.")