Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix keyword index error when storage source is S3 #3182

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions api/core/indexing_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
Expand Down Expand Up @@ -657,18 +658,25 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))

for future in futures:
tokens += future.result()

# create keyword index
create_keyword_thread = threading.Thread(target=self._process_keyword_index,
args=(current_app._get_current_object(),
dataset, dataset_document, documents))
create_keyword_thread.start()
if dataset.indexing_technique == 'high_quality':
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(0, len(documents), chunk_size):
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))

for future in futures:
tokens += future.result()

create_keyword_thread.join()
indexing_end_at = time.perf_counter()

# update document status to completed
Expand All @@ -682,6 +690,24 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
}
)

def _process_keyword_index(self, flask_app, dataset, dataset_document, documents):
# Build the keyword index for `documents`; `_load` starts this method as the
# target of a `threading.Thread` and joins it before recording the end time.
#
# flask_app:        application object whose `app_context()` wraps all work here.
# dataset:          passed to `Keyword(...)`; its `indexing_technique` is read below.
# dataset_document: its `id` is used to select the segments to update.
# documents:        passed to `keyword.create`; each item's `metadata['doc_id']`
#                   is matched against `DocumentSegment.index_node_id`.
with flask_app.app_context():
keyword = Keyword(dataset)
keyword.create(documents)
# For any technique other than 'high_quality', this method marks the segments
# as finished itself.
# NOTE(review): presumably the 'high_quality' path updates segment status in
# `_process_chunk` instead — confirm against the full file.
if dataset.indexing_technique != 'high_quality':
document_ids = [document.metadata['doc_id'] for document in documents]
# Only segments of this document that are still in "indexing" status and whose
# index node id belongs to `documents` are flipped to completed/enabled.
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.utcnow()
})

# NOTE(review): indentation is lost in this view, so it is unclear whether this
# commit is nested under the `if` above or runs unconditionally — confirm.
db.session.commit()

def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
embedding_model_instance, embedding_model_type_instance):
with flask_app.app_context():
Expand All @@ -700,7 +726,7 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d
)

# load index
index_processor.load(dataset, chunk_documents)
index_processor.load(dataset, chunk_documents, with_keywords=False)

document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
Expand Down
142 changes: 75 additions & 67 deletions api/core/rag/datasource/keyword/jieba/jieba.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,56 +24,64 @@ def __init__(self, dataset: Dataset):
self._config = KeywordTableConfig()

def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)

return self
return self

def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()

keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

self._save_dataset_keyword_table(keyword_table)
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()

keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

self._save_dataset_keyword_table(keyword_table)

def text_exists(self, id: str) -> bool:
    """Return whether the node id *id* is indexed under any keyword.

    The keyword table maps each keyword to a collection of node ids.

    The previous implementation computed ``set.union(*keyword_table.values())``,
    which raises ``TypeError`` when the table is empty (the unbound
    ``set.union`` receives no arguments) and also when the first value is not
    a ``set``. An empty or missing table now simply yields ``False``.

    :param id: node id to look for (name kept for interface compatibility).
    :return: ``True`` if any keyword's id collection contains *id*.
    """
    keyword_table = self._get_dataset_keyword_table()
    if not keyword_table:
        # Nothing has been indexed yet, so no id can exist.
        return False
    # Short-circuits on the first hit instead of building a merged set.
    return any(id in node_ids for node_ids in keyword_table.values())

def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)

def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()

ids = [segment.index_node_id for segment in segments]
ids = [segment.index_node_id for segment in segments]

keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

self._save_dataset_keyword_table(keyword_table)
self._save_dataset_keyword_table(keyword_table)

def search(
self, query: str,
Expand Down Expand Up @@ -106,13 +114,15 @@ def search(
return documents

def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
storage.delete(file_key)
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
storage.delete(file_key)

def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
Expand All @@ -135,33 +145,31 @@ def _save_dataset_keyword_table(self, keyword_table):
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))

def _get_dataset_keyword_table(self) -> Optional[dict]:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
with redis_client.lock(lock_name, timeout=20):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
else:
keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
db.session.add(dataset_keyword_table)
db.session.commit()
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
else:
keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
db.session.add(dataset_keyword_table)
db.session.commit()

return {}
return {}

def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
Expand Down
Loading