examples/TCL_rag/config.yaml (12 changes: 6 additions & 6 deletions)
@@ -1,33 +1,33 @@
 llm:
   name: openai
-  base_url: "https://api.gptsapi.net/v1"
-  api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2"
+  base_url: "xxx"
+  api_key: "xxx"
   model: "gpt-4.1-mini"
 
 embedding:
   name: huggingface
-  model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
+  model_name: "xxx"
   model_kwargs:
     device: "cuda:0"
 
 
 
 store:
   name: faiss
-  folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store
+  folder_path: xxx
 
 
 bm25:
   name: bm25
   k: 10
-  data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
+  data_path: xxx
 
 retriever:
   name: vectorstore
 
 reranker:
   name: qwen3
-  model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
+  model_name_or_path: "xxx"
   device_id: "cuda:0"
 
 dataset:
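The "xxx" placeholders above stand in for environment-specific values (API endpoint and key, local model paths, the FAISS folder, and the BM25 data file); they are intentionally left blank and should be filled in locally. A small sanity check along these lines, run from the repository root, can confirm every placeholder has been replaced before the pipeline is built. This is a minimal sketch that assumes PyYAML is installed; the field list simply mirrors the keys shown in this diff.

import yaml  # assumes PyYAML is available in the environment

# Keys that are committed as "xxx" and must be filled in locally.
REQUIRED = [
    ("llm", "base_url"), ("llm", "api_key"),
    ("embedding", "model_name"),
    ("store", "folder_path"),
    ("bm25", "data_path"),
    ("reranker", "model_name_or_path"),
]

with open("examples/TCL_rag/config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

missing = [".".join(keys) for keys in REQUIRED
           if str(cfg.get(keys[0], {}).get(keys[1], "xxx")) == "xxx"]
if missing:
    raise SystemExit(f"Fill in these config values before running: {missing}")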
examples/TCL_rag/test.py (9 changes: 4 additions & 5 deletions)
@@ -24,9 +24,8 @@
     vector_store_config=vector_store_config,
     bm25_retriever_config=bm25_retriever_config)
 
-result = rag.invoke("毛细管设计规范按照什么标准",k=20)
+result = rag.invoke("模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?",k=20)
 
-answer = rag.answer("毛细管设计规范按照什么标准",result)
-
-
-print(answer)
+for i in result:
+    print(i)
+    print("-"*100)
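The new query asks, roughly, "What is the improvement plan for the non-foolproof sensor terminals on the modular unit, and which department is responsible?", and the script now prints each retrieved document instead of a generated answer. A minimal sketch of how the two styles combine, assuming the rag object built earlier in test.py and the invoke/answer methods shown in this diff:

query = "模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?"

# Retrieval only, as in the updated test: inspect the top-20 hits one by one.
result = rag.invoke(query, k=20)
for doc in result:
    print(doc)
    print("-" * 100)

# The answer-generation step removed by this commit can still be layered on top.
answer = rag.answer(query, result)
print(answer)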
rag_factory/Retrieval/Retriever/Retriever_BM25.py (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
 from dataclasses import dataclass, field
-
+import uuid
 from pydantic import ConfigDict, Field, model_validator
 
 logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ def from_texts(
                 f"does not match the length of texts ({len(texts_list)})"
             )
         else:
-            ids_list = [None for _ in texts_list]
+            ids_list = [str(uuid.uuid4()) for _ in texts_list]
 
         # Preprocess the texts
         logger.info(f"Preprocessing {len(texts_list)} texts...")
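With this change, from_texts no longer leaves ids as None when the caller omits them: every text receives a fresh UUID, so downstream stores and result fusion can always key on a stable id. Below is a self-contained sketch of the same pattern (a stand-alone mirror of the logic, not the retriever's actual method):

import uuid

def assign_ids(texts, ids=None):
    # When no ids are given, generate one UUID per text instead of None.
    if ids is not None:
        ids_list = list(ids)
        if len(ids_list) != len(texts):
            raise ValueError(
                f"ids length ({len(ids_list)}) does not match the length of texts ({len(texts)})"
            )
    else:
        ids_list = [str(uuid.uuid4()) for _ in texts]
    return ids_list

print(assign_ids(["capillary design spec", "sensor terminal note"]))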
rag_factory/Retrieval/Retriever/Retriever_MultiPath.py (57 changes: 10 additions & 47 deletions)
@@ -50,8 +50,8 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
 
         Note:
             - Each retriever's results are converted into RetrievalResult format
-            - Multiple input formats are supported: Document objects, dicts, strings, etc.
-            - After fusion, score and rank are stored in each Document's metadata
+            - Inputs will only ever be Document objects
+            - Fusion returns only the ranked Document objects
         """
         top_k = kwargs.get('top_k', 10)
 
@@ -65,43 +65,12 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
             # Convert to RetrievalResult format
             formatted_results = []
             for i, doc in enumerate(documents):
-                if isinstance(doc, Document):
-                    # Already a Document object
-                    retrieval_result = RetrievalResult(
-                        document=doc,
-                        score=getattr(doc, 'score', 1.0),
-                        rank=i + 1
-                    )
-                elif isinstance(doc, dict):
-                    # Dict results have to be converted into a Document object
-                    content = doc.get('content', '')
-                    metadata = doc.get('metadata', {})
-                    doc_id = doc.get('id')
-
-                    document = Document(
-                        content=content,
-                        metadata=metadata,
-                        id=doc_id
-                    )
-
-                    retrieval_result = RetrievalResult(
-                        document=document,
-                        score=doc.get('score', 1.0),
-                        rank=i + 1
-                    )
-                else:
-                    # Strings and other formats are wrapped into a Document object
-                    document = Document(
-                        content=str(doc),
-                        metadata={},
-                        id=None
-                    )
-
-                    retrieval_result = RetrievalResult(
-                        document=document,
-                        score=1.0,
-                        rank=i + 1
-                    )
+                # Inputs will only ever be Document objects
+                retrieval_result = RetrievalResult(
+                    document=doc,
+                    score=getattr(doc, 'score', 1.0),
+                    rank=i + 1
+                )
                 formatted_results.append(retrieval_result)
 
             all_results.append(formatted_results)
@@ -116,16 +85,10 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
 
         fused_results = self.fusion_method.fuse(all_results, top_k)
 
-        # Convert back to Document format
+        # Convert back to Document format; return only the ranked Document objects
         documents = []
         for result in fused_results:
-            doc = result.document
-            # Preserve score and rank by adding them to metadata
-            if doc.metadata is None:
-                doc.metadata = {}
-            doc.metadata['score'] = result.score
-            doc.metadata['rank'] = result.rank
-            documents.append(doc)
+            documents.append(result.document)
 
         return documents
 
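After this simplification, _get_relevant_documents trusts every retriever to hand back Document objects, wraps them in RetrievalResult only so the fusion step can rank them, and returns plain Documents without copying score or rank into metadata. The toy below illustrates that contract end to end; Document, RetrievalResult, and the score-then-rank fusion here are simplified stand-ins, not the project's actual classes or fusion method.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Document:                      # stand-in for the project's Document
    content: str
    metadata: Dict = field(default_factory=dict)

@dataclass
class RetrievalResult:               # stand-in for the project's RetrievalResult
    document: Document
    score: float
    rank: int

def fuse_paths(paths: List[List[Document]], top_k: int) -> List[Document]:
    # Wrap each path's hits, just as the simplified loop above now does.
    results = []
    for docs in paths:
        results.extend(
            RetrievalResult(document=d, score=getattr(d, "score", 1.0), rank=i + 1)
            for i, d in enumerate(docs)
        )
    # Naive fusion stand-in: highest score first, earliest rank breaks ties.
    results.sort(key=lambda r: (-r.score, r.rank))
    # Return plain Documents: score and rank are no longer written into metadata.
    return [r.document for r in results[:top_k]]

print(fuse_paths([[Document("hit A")], [Document("hit B")]], top_k=2))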
requirements.txt (9 changes: 5 additions & 4 deletions)
@@ -14,15 +14,13 @@ llama-index
llama-index-core
peewee

mineru[core]

rank_bm25
faiss_gpu



# streamlit
# for ocr
PyMuPDF
openai
qwen_vl_utils
transformers==4.51.3
huggingface_hub
@@ -31,3 +29,6 @@ flash-attn==2.8.0.post2
# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
accelerate
dashscope
# torch pinned to the CUDA 12.8 builds; the extra index below lets pip find the cu128 wheels
--extra-index-url https://download.pytorch.org/whl/cu128
torch==2.7.0
torchvision==0.22.0
torchaudio==2.7.0

mineru[core]
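The torch, torchvision, and torchaudio pins are meant to resolve to the CUDA 12.8 builds served from the PyTorch wheel index. After installation, a quick check along these lines confirms the expected build was picked up (expected values are inferred from the pins above):

import torch

print(torch.__version__)          # expected to start with 2.7.0
print(torch.version.cuda)         # expected to report 12.8 for the cu128 wheels
print(torch.cuda.is_available())  # True when a compatible GPU and driver are present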