fix keyword index error when storage source is S3 (langgenius#3182)
JohnJyong authored Apr 8, 2024
1 parent 1f9509a commit 7aa0f4e
Showing 2 changed files with 114 additions and 80 deletions.
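
Context for the change: the keyword table is stored as a single serialized blob (in the database, or as a file on S3 when that is the storage source), and every writer updates it by reading the whole table, modifying it, and writing it back. Before this commit, the Redis lock only guarded the read inside _get_dataset_keyword_table, so concurrent indexing threads could each read the same snapshot and overwrite one another's keywords. Below is a minimal, self-contained sketch of that lost-update race and of the fix, using threading.Lock as a stand-in for the Redis lock; the store dict and function names are illustrative, not from the repository.

    import threading

    # Stand-in for the serialized keyword table in the database or S3:
    # one blob that every writer reads, modifies, and writes back in full.
    store = {"table": {}}
    lock = threading.Lock()  # stand-in for redis_client.lock(lock_name, timeout=600)

    def add_keywords_unsafe(doc_id, keywords):
        # No lock around the whole cycle: two threads can read the same
        # snapshot, and the later write silently drops the earlier update.
        snapshot = dict(store["table"])   # read
        snapshot[doc_id] = keywords       # modify
        store["table"] = snapshot         # write back

    def add_keywords_safe(doc_id, keywords):
        # Holding the lock across read + modify + write serializes writers,
        # which is what this commit does with the dataset-scoped Redis lock.
        with lock:
            snapshot = dict(store["table"])
            snapshot[doc_id] = keywords
            store["table"] = snapshot

    threads = [threading.Thread(target=add_keywords_safe, args=("doc-%d" % i, ["kw"]))
               for i in range(50)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert len(store["table"]) == 50  # with add_keywords_unsafe, entries can be lost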
52 changes: 39 additions & 13 deletions api/core/indexing_runner.py
@@ -19,6 +19,7 @@
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -657,18 +658,25 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
         if embedding_model_instance:
             embedding_model_type_instance = embedding_model_instance.model_type_instance
             embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-            futures = []
-            for i in range(0, len(documents), chunk_size):
-                chunk_documents = documents[i:i + chunk_size]
-                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
-                                               chunk_documents, dataset,
-                                               dataset_document, embedding_model_instance,
-                                               embedding_model_type_instance))
-
-            for future in futures:
-                tokens += future.result()
+        # create keyword index
+        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
+                                                 args=(current_app._get_current_object(),
+                                                       dataset, dataset_document, documents))
+        create_keyword_thread.start()
+        if dataset.indexing_technique == 'high_quality':
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = []
+                for i in range(0, len(documents), chunk_size):
+                    chunk_documents = documents[i:i + chunk_size]
+                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
+                                                   chunk_documents, dataset,
+                                                   dataset_document, embedding_model_instance,
+                                                   embedding_model_type_instance))
+
+                for future in futures:
+                    tokens += future.result()

+        create_keyword_thread.join()
         indexing_end_at = time.perf_counter()

         # update document status to completed
@@ -682,6 +690,24 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
             }
         )

+    def _process_keyword_index(self, flask_app, dataset, dataset_document, documents):
+        with flask_app.app_context():
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+            if dataset.indexing_technique != 'high_quality':
+                document_ids = [document.metadata['doc_id'] for document in documents]
+                db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == dataset_document.id,
+                    DocumentSegment.index_node_id.in_(document_ids),
+                    DocumentSegment.status == "indexing"
+                ).update({
+                    DocumentSegment.status: "completed",
+                    DocumentSegment.enabled: True,
+                    DocumentSegment.completed_at: datetime.datetime.utcnow()
+                })
+
+                db.session.commit()
+
     def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                        embedding_model_instance, embedding_model_type_instance):
         with flask_app.app_context():
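
The control flow above, reduced to a standalone sketch: keyword extraction now runs exactly once on a dedicated thread (rather than once per chunk), embeddings are only computed for 'high_quality' datasets, and the join ensures the document is not marked completed before the keyword index exists. build_keyword_index and embed_chunk are illustrative placeholders for _process_keyword_index and _process_chunk, not the project's code.

    import concurrent.futures
    import threading

    def build_keyword_index(documents):
        pass  # placeholder for Keyword(dataset).create(documents)

    def embed_chunk(chunk):
        return len(chunk)  # placeholder for _process_chunk's token count

    def load(documents, high_quality, chunk_size=10):
        # Keyword indexing happens once, concurrently with embedding.
        keyword_thread = threading.Thread(target=build_keyword_index, args=(documents,))
        keyword_thread.start()

        tokens = 0
        if high_quality:  # embeddings only for high_quality datasets
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = [executor.submit(embed_chunk, documents[i:i + chunk_size])
                           for i in range(0, len(documents), chunk_size)]
                tokens = sum(f.result() for f in futures)

        keyword_thread.join()  # don't finish before the keyword index is written
        return tokens

    print(load(["doc-%d" % i for i in range(25)], high_quality=True))  # -> 25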
@@ -700,7 +726,7 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
             )

             # load index
-            index_processor.load(dataset, chunk_documents)
+            index_processor.load(dataset, chunk_documents, with_keywords=False)

             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
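
The one-line change to _process_chunk is the other half of the fix: each chunk worker used to rebuild the keyword index as a side effect of index_processor.load, so up to ten workers raced on the same table. A simplified sketch of how an index processor would honor the flag; the real classes live under api/core/rag/index_processor/ and this shape is an assumption, not their actual implementation.

    class IndexProcessorSketch:
        # Hypothetical simplification of a BaseIndexProcessor subclass.
        def load(self, dataset, documents, with_keywords=True):
            self.load_vectors(dataset, documents)  # per-chunk vector indexing
            if with_keywords:
                # Chunk workers now pass with_keywords=False, so only the
                # dedicated thread in _process_keyword_index writes keywords.
                self.create_keyword_index(dataset, documents)

        def load_vectors(self, dataset, documents):
            pass  # placeholder

        def create_keyword_index(self, dataset, documents):
            pass  # placeholder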
142 changes: 75 additions & 67 deletions api/core/rag/datasource/keyword/jieba/jieba.py
@@ -24,56 +24,64 @@ def __init__(self, dataset: Dataset):
         self._config = KeywordTableConfig()

     def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+            keyword_table = self._get_dataset_keyword_table()
+            for text in texts:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)

-        return self
+            return self

     def add_texts(self, texts: list[Document], **kwargs):
-        keyword_table_handler = JiebaKeywordTableHandler()
-
-        keyword_table = self._get_dataset_keyword_table()
-        keywords_list = kwargs.get('keywords_list', None)
-        for i in range(len(texts)):
-            text = texts[i]
-            if keywords_list:
-                keywords = keywords_list[i]
-            else:
-                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        self._save_dataset_keyword_table(keyword_table)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+
+            keyword_table = self._get_dataset_keyword_table()
+            keywords_list = kwargs.get('keywords_list', None)
+            for i in range(len(texts)):
+                text = texts[i]
+                if keywords_list:
+                    keywords = keywords_list[i]
+                else:
+                    keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+            self._save_dataset_keyword_table(keyword_table)

     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
         return id in set.union(*keyword_table.values())

     def delete_by_ids(self, ids: list[str]) -> None:
-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)

     def delete_by_document_id(self, document_id: str):
-        # get segment ids by document_id
-        segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self.dataset.id,
-            DocumentSegment.document_id == document_id
-        ).all()
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            # get segment ids by document_id
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.document_id == document_id
+            ).all()

-        ids = [segment.index_node_id for segment in segments]
+            ids = [segment.index_node_id for segment in segments]

-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)

     def search(
             self, query: str,
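
Note the pattern repeated in each public mutator above: the lock name is scoped per dataset, so different datasets never contend, and the timeout grows from 20 seconds (when it only guarded the read in _get_dataset_keyword_table, see the last hunk below) to 600 seconds, since it now covers keyword extraction and a full table save, which can be slow when the blob lives on S3.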
@@ -106,13 +114,15 @@ def search(
         return documents

     def delete(self) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-            if dataset_keyword_table.data_source_type != 'database':
-                file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
-                storage.delete(file_key)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                db.session.delete(dataset_keyword_table)
+                db.session.commit()
+                if dataset_keyword_table.data_source_type != 'database':
+                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                    storage.delete(file_key)

     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
@@ -135,33 +145,31 @@ def _save_dataset_keyword_table(self, keyword_table):
             storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))

     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
-        with redis_client.lock(lock_name, timeout=20):
-            dataset_keyword_table = self.dataset.dataset_keyword_table
-            if dataset_keyword_table:
-                keyword_table_dict = dataset_keyword_table.keyword_table_dict
-                if keyword_table_dict:
-                    return keyword_table_dict['__data__']['table']
-            else:
-                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
-                dataset_keyword_table = DatasetKeywordTable(
-                    dataset_id=self.dataset.id,
-                    keyword_table='',
-                    data_source_type=keyword_data_source_type,
-                )
-                if keyword_data_source_type == 'database':
-                    dataset_keyword_table.keyword_table = json.dumps({
-                        '__type__': 'keyword_table',
-                        '__data__': {
-                            "index_id": self.dataset.id,
-                            "summary": None,
-                            "table": {}
-                        }
-                    }, cls=SetEncoder)
-                db.session.add(dataset_keyword_table)
-                db.session.commit()
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            keyword_table_dict = dataset_keyword_table.keyword_table_dict
+            if keyword_table_dict:
+                return keyword_table_dict['__data__']['table']
+        else:
+            keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table='',
+                data_source_type=keyword_data_source_type,
+            )
+            if keyword_data_source_type == 'database':
+                dataset_keyword_table.keyword_table = json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            db.session.add(dataset_keyword_table)
+            db.session.commit()

-            return {}
+        return {}

     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:
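
For reference, redis_client.lock(...) used throughout is redis-py's distributed Lock, which works as a context manager. A standalone usage sketch of the same pattern follows; the connection details and key names are assumptions, and in the real code the table blob lives in the database or on S3 rather than in Redis.

    import json
    import redis

    r = redis.Redis(host="localhost", port=6379)  # assumed connection details

    def update_keyword_table(dataset_id, doc_id, keywords):
        # One lock per dataset, held across the entire read-modify-write,
        # mirroring the commit; timeout=600 bounds how long a crashed
        # holder can block other writers before the lock auto-expires.
        with r.lock('keyword_indexing_lock_{}'.format(dataset_id), timeout=600):
            raw = r.get('keyword_table_{}'.format(dataset_id))            # read
            table = json.loads(raw) if raw else {}
            table[doc_id] = keywords                                      # modify
            r.set('keyword_table_{}'.format(dataset_id), json.dumps(table))  # write back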
