Skip to content

Commit

Permalink
fix dataset retrieval in dataset mode (langgenius#3334)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Apr 10, 2024
1 parent 3426fc7 commit 83c3d08
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/core/rag/extractor/csv_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(

def extract(self) -> list[Document]:
"""Load data into document objects."""
docs = []
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
Expand Down
92 changes: 92 additions & 0 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, cast

from flask import Flask, current_app
from langchain.tools import BaseTool

from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
Expand All @@ -17,6 +18,8 @@
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
Expand Down Expand Up @@ -373,3 +376,92 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int,
)

all_documents.extend(documents)

def to_dataset_retriever_tool(self, tenant_id: str,
                              dataset_ids: list[str],
                              retrieve_config: DatasetRetrieveConfigEntity,
                              return_resource: bool,
                              invoke_from: InvokeFrom,
                              hit_callback: DatasetIndexToolCallbackHandler) \
        -> Optional[list[BaseTool]]:
    """
    Build the retriever tool(s) used to fetch information from datasets.

    Each dataset id is looked up within the given tenant; datasets that do
    not exist or have no available documents are dropped. With the SINGLE
    retrieve strategy one ``DatasetRetrieverTool`` is created per remaining
    dataset, using that dataset's own retrieval model (or a built-in default).
    With the MULTIPLE strategy a single ``DatasetMultiRetrieverTool`` covering
    all remaining datasets is created from ``retrieve_config``.

    :param tenant_id: tenant id
    :param dataset_ids: dataset ids
    :param retrieve_config: retrieve config
    :param return_resource: return resource
    :param invoke_from: invoke from
    :param hit_callback: hit callback
    :return: list of created tools; empty when the strategy is neither
        SINGLE nor MULTIPLE, or when SINGLE finds no usable dataset
    """
    tools = []
    available_datasets = []
    for dataset_id in dataset_ids:
        # Look the dataset up scoped to the tenant so that ids belonging to
        # another tenant are never resolved.
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        # Skip ids that do not resolve to a dataset for this tenant.
        if not dataset:
            continue

        # Skip datasets that have no documents available for retrieval.
        if dataset.available_document_count == 0:
            continue

        available_datasets.append(dataset)

    if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
        # Fallback used when a dataset carries no retrieval model of its own.
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        for dataset in available_datasets:
            retrieval_model_config = dataset.retrieval_model or default_retrieval_model

            top_k = retrieval_model_config['top_k']

            # The threshold is only honoured when explicitly enabled.
            score_threshold = None
            if retrieval_model_config.get("score_threshold_enabled"):
                score_threshold = retrieval_model_config.get("score_threshold")

            tool = DatasetRetrieverTool.from_dataset(
                dataset=dataset,
                top_k=top_k,
                score_threshold=score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source()
            )

            tools.append(tool)
    elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
        # Assumes reranking_model may be unset (None) on the config --
        # TODO confirm against DatasetRetrieveConfigEntity. Falling back to
        # an empty mapping avoids an AttributeError on ``.get`` and yields
        # None for both names, exactly as a mapping without those keys would.
        reranking_model = retrieve_config.reranking_model or {}

        # NOTE(review): the tool is created even when no dataset survived
        # the availability filter above; kept as in the original behavior.
        tool = DatasetMultiRetrieverTool.from_dataset(
            dataset_ids=[dataset.id for dataset in available_datasets],
            tenant_id=tenant_id,
            top_k=retrieve_config.top_k or 2,
            score_threshold=retrieve_config.score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=return_resource,
            retriever_from=invoke_from.to_source(),
            reranking_provider_name=reranking_model.get('reranking_provider_name'),
            reranking_model_name=reranking_model.get('reranking_model_name')
        )

        tools.append(tool)

    return tools

0 comments on commit 83c3d08

Please sign in to comment.