Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dataset retrival in dataset mode #3334

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/core/rag/extractor/csv_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(

def extract(self) -> list[Document]:
"""Load data into document objects."""
docs = []
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
Expand Down
92 changes: 92 additions & 0 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, cast

from flask import Flask, current_app
from langchain.tools import BaseTool

from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
Expand All @@ -17,6 +18,8 @@
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
Expand Down Expand Up @@ -373,3 +376,92 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int,
)

all_documents.extend(documents)

def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
:param dataset_ids: dataset ids
:param retrieve_config: retrieve config
:param return_resource: return resource
:param invoke_from: invoke from
:param hit_callback: hit callback
"""
tools = []
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()

# pass if dataset is not available
if not dataset:
continue

# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
continue

available_datasets.append(dataset)

if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}

for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model

# get top k
top_k = retrieval_model_config['top_k']

# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")

tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source()
)

tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
)

tools.append(tool)

return tools
Loading