Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
01ca681
Merge pull request #4 from DataArcTech/main
Mi221e Aug 6, 2025
6d11774
Merge pull request #5 from DataArcTech/main
Mi221e Aug 6, 2025
ef84055
fix: invoke return all documents
Aug 6, 2025
80cc581
fix: import path
Aug 6, 2025
1042245
fix: import path
Aug 6, 2025
78af535
add: Rerank module
Aug 6, 2025
a9154bd
Merge pull request #6 from DataArcTech/feat/Store
Mi221e Aug 6, 2025
2c649ea
add: MultiPath Retriever
Aug 7, 2025
ab0ceb7
add: Fusion utils function
Aug 7, 2025
efeb01a
fix:invoke k
Aug 7, 2025
6142d17
fix: import path
Aug 7, 2025
fb9b60b
delete Parser_Docling
Aug 7, 2025
e87fb82
Merge branch 'feat/parser' of github.com:DataArcTech/RAG-Factory into…
Aug 7, 2025
08360d8
Add files via upload
snahualimi Aug 8, 2025
e21aa2e
Update requirements.txt
snahualimi Aug 8, 2025
6bcc668
Update readme.md
snahualimi Aug 8, 2025
01b15e0
Update readme.md
snahualimi Aug 8, 2025
2653718
Update readme.md
snahualimi Aug 11, 2025
fb91761
Merge pull request #7 from DataArcTech/feat/parser
Mi221e Aug 11, 2025
9e5d43d
Merge pull request #8 from DataArcTech/main
Mi221e Aug 11, 2025
6650241
Add files via upload
snahualimi Aug 11, 2025
84aa598
Add files via upload
snahualimi Aug 11, 2025
2dddf94
Add files via upload
snahualimi Aug 11, 2025
8a49457
Update requirements.txt
snahualimi Aug 11, 2025
24bf66b
Update readme.md
snahualimi Aug 11, 2025
4ea26f5
Update readme.md
snahualimi Aug 11, 2025
9c2b276
Update readme.md
snahualimi Aug 11, 2025
a52d371
Update fig_recognize.py
snahualimi Aug 11, 2025
25f74eb
Update fig_recognize.py
snahualimi Aug 11, 2025
39e066b
add custom openai_like llm
Aug 11, 2025
8562833
add custom llm abstract base
Aug 11, 2025
608ca71
add custom openai like llm
Aug 11, 2025
f3d1ad2
Update vllm_launch.py
snahualimi Aug 11, 2025
3989c1f
add bm25 faiss_construtor TCL_rag
Aug 11, 2025
b4e22c6
add registry
Aug 11, 2025
ea4f08c
Merge pull request #9 from DataArcTech/feat/Store
Mi221e Aug 11, 2025
a87c27d
remove requirement.txt to root dir
Aug 11, 2025
9027b54
add dots.ocr requirement
Aug 11, 2025
306fcf3
Merge pull request #10 from DataArcTech/feat/parser
Mi221e Aug 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/TCL_rag/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Configuration for the TCL RAG example. test.py loads this file and hands each
# top-level section to rag_flow.TCL_RAG as a keyword-argument mapping.

# Unpacked into LLMRegistry.create(...).
llm:
name: openai
base_url: "https://api.gptsapi.net/v1"
# NOTE(review): a credential is committed here in plain text. Rotate it and
# supply the key from outside version control before merging.
api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2"
model: "gpt-4.1-mini"

# Unpacked into EmbeddingRegistry.create(...).
embedding:
name: huggingface
# NOTE(review): machine-specific absolute path — adjust for your environment.
model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
model_kwargs:
device: "cuda:0"



# Unpacked into VectorStoreRegistry.load(...) together with the embedding object.
store:
name: faiss
# NOTE(review): machine-specific absolute path — adjust for your environment.
folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store


# Unpacked into RetrieverRegistry.create(...); TCL_RAG additionally reads
# `data_path` (corpus file) and `k` when building the BM25 retriever.
bm25:
name: bm25
k: 10
# NOTE(review): machine-specific absolute path — adjust for your environment.
data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json

# Unpacked into RetrieverRegistry.create(...) together with the loaded vector store.
retriever:
name: vectorstore

# Unpacked into RerankerRegistry.create(...).
reranker:
name: qwen3
# NOTE(review): machine-specific absolute path — adjust for your environment.
model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
device_id: "cuda:0"

# Not read by test.py in this example.
dataset:
name: TCL

85 changes: 85 additions & 0 deletions examples/TCL_rag/rag_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
import os

# 添加 RAG-Factory 目录到 Python 路径
rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, rag_factory_path)

from rag_factory.llms import LLMRegistry
from rag_factory.Embed import EmbeddingRegistry
from rag_factory.Store import VectorStoreRegistry
from rag_factory.Retrieval import RetrieverRegistry
from rag_factory.rerankers import RerankerRegistry
from rag_factory.Retrieval import Document
from typing import List
import json


class TCL_RAG:
    """RAG pipeline for the TCL example.

    Wires together an LLM, an embedding model, a loaded vector store, a BM25
    retriever, a vector-store retriever, a "multipath" retriever built from
    both retrievers, and a reranker. Every component is obtained from its
    registry using the configuration mapping passed to the constructor.
    """

    def __init__(
        self,
        *,
        llm_config=None,
        embedding_config=None,
        vector_store_config=None,
        bm25_retriever_config=None,
        retriever_config=None,
        reranker_config=None,
    ):
        """Create all pipeline components.

        Each argument is a mapping that is unpacked as keyword arguments into
        the corresponding registry call; a falsy value is replaced by an empty
        mapping. ``bm25_retriever_config`` must also provide ``data_path``
        (corpus file) and ``k``, which are read directly when the BM25
        retriever is built from documents.
        """
        llm_kwargs = llm_config if llm_config else {}
        embedding_kwargs = embedding_config if embedding_config else {}
        store_kwargs = vector_store_config if vector_store_config else {}
        bm25_kwargs = bm25_retriever_config if bm25_retriever_config else {}
        retriever_kwargs = retriever_config if retriever_config else {}
        reranker_kwargs = reranker_config if reranker_config else {}

        # Construction order matters: the vector store needs the embedding
        # model, and the retrievers need the vector store.
        self.llm = LLMRegistry.create(**llm_kwargs)
        self.embedding = EmbeddingRegistry.create(**embedding_kwargs)
        self.vector_store = VectorStoreRegistry.load(
            **store_kwargs, embedding=self.embedding
        )

        # The registry returns a BM25 retriever first; it is then replaced by
        # the retriever that its ``from_documents`` builds over the corpus.
        self.bm25_retriever = RetrieverRegistry.create(**bm25_kwargs)
        build_from_documents = self.bm25_retriever.from_documents
        self.bm25_retriever = build_from_documents(
            documents=self._load_data(bm25_kwargs["data_path"]),
            preprocess_func=self.chinese_preprocessing_func,
            k=bm25_kwargs["k"],
        )

        self.retriever = RetrieverRegistry.create(
            **retriever_kwargs, vectorstore=self.vector_store
        )
        self.multi_path_retriever = RetrieverRegistry.create(
            "multipath", retrievers=[self.bm25_retriever, self.retriever]
        )
        self.reranker = RerankerRegistry.create(**reranker_kwargs)

    def invoke(self, query: str, k: int = None):
        """Run the multipath retriever for ``query``, forwarding ``k`` as ``top_k``."""
        retrieved = self.multi_path_retriever.invoke(query, top_k=k)
        return retrieved

    def rerank(self, query: str, documents: List[Document], k: int = None, batch_size: int = 8):
        """Delegate to the reranker with ``query``, ``documents``, ``k`` and ``batch_size``."""
        reranked = self.reranker.rerank(query, documents, k, batch_size)
        return reranked

    def _load_data(self, data_path: str):
        """Read the JSON corpus at ``data_path`` and return one ``Document`` per record.

        ``content`` is taken from each record's ``full_content`` and
        ``metadata["title"]`` from ``original_filename``; both default to an
        empty string when the key is missing.
        """
        with open(data_path, "r", encoding="utf-8") as handle:
            records = json.load(handle)
        return [
            Document(
                content=record.get("full_content", ""),
                metadata={"title": record.get("original_filename", "")},
            )
            for record in records
        ]

    def chinese_preprocessing_func(self, text: str) -> str:
        """Segment ``text`` with jieba and join the tokens with single spaces."""
        import jieba

        tokens = jieba.cut(text)
        return " ".join(tokens)

    def answer(self, query: str, documents: List[Document]):
        """Ask the LLM to answer ``query`` from the content of ``documents``.

        The documents' ``content`` fields are joined with newlines and placed
        in the prompt ahead of the question. Returns whatever
        ``self.llm.chat`` returns for the system + user message pair.
        """
        context = "\n".join(doc.content for doc in documents)
        template = (
            "你是一位工业领域的专家。根据以下检索到的材料回答用户问题。"
            "如果回答所需信息未在材料中出现,请说明无法找到相关信息。\n\n"
            "{context}\n\n"
            "用户问题:{question}\n"
            "答复:"
        )
        prompt = template.format(question=query, context=context)
        system_message = {"role": "system", "content": "你是一位工业领域的专家。"}
        user_message = {"role": "user", "content": prompt}
        return self.llm.chat([system_message, user_message])




32 changes: 32 additions & 0 deletions examples/TCL_rag/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from rag_flow import TCL_RAG
import yaml

# Load the configuration file that sits next to this script. Resolving it from
# __file__ (instead of a machine-specific absolute path) lets the example run
# from any checkout location and any working directory.
import os

CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# One configuration mapping per pipeline component, keyed by config section.
llm_config = config['llm']
embedding_config = config['embedding']
reranker_config = config['reranker']
bm25_retriever_config = config['bm25']
retriever_config = config['retriever']
vector_store_config = config['store']


if __name__ == "__main__":
    rag = TCL_RAG(
        llm_config=llm_config,
        embedding_config=embedding_config,
        reranker_config=reranker_config,
        retriever_config=retriever_config,
        vector_store_config=vector_store_config,
        bm25_retriever_config=bm25_retriever_config,
    )

    # Bind the question once so retrieval and answering always use the same text.
    query = "毛细管设计规范按照什么标准"

    result = rag.invoke(query, k=20)
    answer = rag.answer(query, result)

    print(answer)
3 changes: 3 additions & 0 deletions examples/bm25/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Unpacked into RetrieverRegistry.create(...) by main.py; `k` is also passed
# to the retriever's from_documents call.
retriever:
name: bm25
k: 8
36 changes: 36 additions & 0 deletions examples/bm25/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys
import os

rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, rag_factory_path)

import json
from rag_factory.Retrieval import Document
from rag_factory.Retrieval import RetrieverRegistry

import yaml


def load_data(jsonl_path: str):
    """Load the corpus file and turn every record into a ``Document``.

    Despite the parameter name, the file is parsed with ``json.load``, so it
    must be a single JSON document that iterates over objects, not JSON Lines.

    Args:
        jsonl_path: Path to the UTF-8 encoded JSON corpus file.

    Returns:
        A list with one ``Document`` per record. ``content`` comes from the
        record's ``full_content`` (empty string when missing).
        ``metadata["title"]`` comes from ``original_title`` and falls back to
        ``original_filename`` when that is missing or empty.
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    docs = []
    for item in data:
        # The other loaders of this dataset read the title from
        # "original_filename"; fall back to it so titles are not silently
        # empty when "original_title" is absent.
        title = item.get("original_title", "")
        if not title:
            title = item.get("original_filename", "")
        docs.append(
            Document(content=item.get("full_content", ""), metadata={"title": title})
        )
    return docs

def chinese_preprocessing_func(text: str) -> str:
    """Segment ``text`` with jieba and return the tokens joined by single spaces."""
    # Imported lazily so jieba is only needed once preprocessing actually runs.
    import jieba

    tokens = jieba.cut(text)
    return " ".join(tokens)

if __name__ == "__main__":
    # Corpus location: the first command-line argument when given, otherwise
    # the path the example was originally written against.
    default_data_path = "/data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json"
    data_path = sys.argv[1] if len(sys.argv) > 1 else default_data_path
    docs = load_data(data_path)

    # Resolve config.yaml next to this script rather than through a
    # machine-specific absolute path, so the example runs from any checkout.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    retriever_config = config["retriever"]
    bm25_retriever = RetrieverRegistry.create(**retriever_config)
    # Replace the registry-created retriever with the one built over the corpus.
    bm25_retriever = bm25_retriever.from_documents(
        documents=docs,
        preprocess_func=chinese_preprocessing_func,
        k=retriever_config["k"],
    )

    print(bm25_retriever.invoke("什么是TCL?"))
14 changes: 14 additions & 0 deletions examples/faiss_construct/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
store:
name: faiss # vector database backend
folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store # directory the built index is saved to


embedding:
name: huggingface # embedding model implementation
model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" # model path
model_kwargs:
device: "cuda:1" # device

dataset:
name: TCL
data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
43 changes: 43 additions & 0 deletions examples/faiss_construct/faiss_constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import sys
import os

# 添加 RAG-Factory 目录到 Python 路径
rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, rag_factory_path)

from rag_factory.Store import VectorStoreRegistry
from rag_factory.Embed import EmbeddingRegistry
import yaml
from rag_factory.Retrieval import Document
import json


# Resolve config.yaml next to this script rather than through a
# machine-specific absolute path, so the example runs from any checkout.
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

store_config = config["store"]
embedding_config = config["embedding"]
# NOTE: despite its name, this holds the dataset file path, not the whole
# "dataset" section. The name is kept for backward compatibility.
dataset_config = config["dataset"]["data_path"]
embedding = EmbeddingRegistry.create(**embedding_config)
store = VectorStoreRegistry.create(**store_config, embedding=embedding)


if __name__ == "__main__":

    # Read the corpus; the file handle is only needed while parsing.
    with open(dataset_config, "r", encoding="utf-8") as f:
        data = json.load(f)

    # One Document per record: text from "full_content", title from
    # "original_filename" (None when the key is missing).
    docs = [
        Document(
            content=item.get("full_content", ""),
            metadata={"title": item.get("original_filename")},
        )
        for item in data
    ]

    # Build the vector store from the documents.
    vectorstore = store.from_documents(docs, embedding=embedding)

    # Persist it to the configured folder.
    vectorstore.save_local(store_config["folder_path"])
9 changes: 5 additions & 4 deletions rag_factory/Embed/Embedding_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from dataclasses import dataclass
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List

class Embeddings(ABC):
"""嵌入接口"""

@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.

Args:
Expand All @@ -19,7 +20,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
pass

@abstractmethod
def embed_query(self, text: str) -> list[float]:
def embed_query(self, text: str) -> List[float]:
"""Embed query text.

Args:
Expand All @@ -30,7 +31,7 @@ def embed_query(self, text: str) -> list[float]:
"""
pass

async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.

Args:
Expand All @@ -43,7 +44,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
ThreadPoolExecutor(), self.embed_documents, texts
)

async def aembed_query(self, text: str) -> list[float]:
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text.

Args:
Expand Down
3 changes: 2 additions & 1 deletion rag_factory/Embed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .Embedding_Base import Embeddings
from .Embedding_Huggingface import HuggingFaceEmbeddings
from .registry import EmbeddingRegistry

__all__ = ["Embeddings", "HuggingFaceEmbeddings"]
__all__ = ["Embeddings", "HuggingFaceEmbeddings", "EmbeddingRegistry"]
79 changes: 79 additions & 0 deletions rag_factory/Embed/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Dict, Type, Any, Optional, List
import logging
from .Embedding_Huggingface import HuggingFaceEmbeddings
from .Embedding_Base import Embeddings

class EmbeddingRegistry:
    """Name-to-class registry used to look up and instantiate embedding models."""

    # Shared, class-level mapping from registered name to embedding class.
    _embeddings: Dict[str, Type[Embeddings]] = {}

    @classmethod
    def register(cls, name: str, embedding_class: Type[Embeddings]):
        """Register ``embedding_class`` under ``name``.

        An existing entry with the same name is replaced.

        Args:
            name: Name the class is registered under.
            embedding_class: Embedding class to associate with ``name``.
        """
        cls._embeddings[name] = embedding_class

    @classmethod
    def create(cls, name: str, **kwargs) -> Embeddings:
        """Instantiate the embedding class registered under ``name``.

        Args:
            name: Registered name of the embedding class.
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of the registered class.

        Raises:
            ValueError: If ``name`` has not been registered; the message lists
                the names that are available.
        """
        registry = cls._embeddings
        if name in registry:
            return registry[name](**kwargs)

        known_names = list(registry)
        raise ValueError(f"嵌入模型 '{name}' 未注册。可用的模型: {known_names}")

    @classmethod
    def list_embeddings(cls) -> List[str]:
        """Return the names of all registered embedding classes."""
        return [*cls._embeddings]

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Return True when ``name`` is registered, otherwise False."""
        registry = cls._embeddings
        return name in registry

    @classmethod
    def unregister(cls, name: str) -> bool:
        """Remove the entry registered under ``name``.

        Returns:
            True if an entry was removed, False if ``name`` was not registered.
        """
        if name not in cls._embeddings:
            return False
        del cls._embeddings[name]
        return True


# Register the default embedding model under the name "huggingface".
EmbeddingRegistry.register("huggingface", HuggingFaceEmbeddings)
Loading