Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 63 additions & 43 deletions rag_factory/Retrieval/Retriever/Retriever_BM25.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from pydantic import ConfigDict, Field, model_validator

logger = logging.getLogger(__name__)

import numpy as np
from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document


def default_preprocessing_func(text: str) -> List[str]:
    """Default text preprocessing function; only effective for English-like text.

    Splits the input on whitespace. No lowercasing, stemming, or punctuation
    stripping is performed, so languages without whitespace word boundaries
    (e.g. Chinese) need a dedicated tokenizer instead.

    Args:
        text: Input text.

    Returns:
        List of whitespace-separated tokens (empty list for empty input).
    """
    return text.split()


def chinese_preprocessing_func(text: str) -> List[str]:
    """Tokenize Chinese text for BM25 indexing.

    Uses jieba word segmentation when the package is available. If jieba is
    not installed, logs a warning and falls back to plain whitespace
    splitting (which is a poor tokenizer for Chinese).

    Args:
        text: Input Chinese text.

    Returns:
        List of segmented tokens.
    """
    try:
        import jieba
    except ImportError:
        # Graceful degradation: keep the retriever usable without jieba.
        logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba")
        return text.split()
    return list(jieba.cut(text))


class BM25Retriever(BaseRetriever):
"""BM25 检索器实现

基于 BM25 算法的文档检索器。
使用 rank_bm25 库实现高效的 BM25 搜索。

注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档,
但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。
对于频繁更新的场景,建议使用 VectorStoreRetriever。

"""
BM25Retriever 是一个基于 BM25 算法的文档检索器,适用于信息检索、问答系统、知识库等场景下的高效文本相关性排序。

该类通过集成 rank_bm25 库,实现了对文档集合的 BM25 检索,支持文档的动态添加、删除、批量构建索引等操作。
适合文档集合相对静态、检索速度要求较高的场景。对于频繁增删文档的场景,建议使用向量检索(如 VectorStoreRetriever)。

主要特性:
- 支持从文本列表或 Document 对象列表快速构建 BM25 检索器。
- 支持自定义分词/预处理函数,适配不同语言和分词需求。
- 支持动态添加、删除文档(每次操作会重建索引,适合中小规模数据集)。
- 可获取检索分数、top-k 文档及分数、检索器配置信息等。
- 兼容异步文档添加/删除,便于大规模数据处理。
- 通过 Pydantic 校验参数,保证配置安全。

主要参数:
vectorizer (Any): BM25 向量化器实例(通常为 BM25Okapi)。
docs (List[Document]): 当前检索器持有的文档对象列表。
k (int): 默认返回的相关文档数量。
preprocess_func (Callable): 文本分词/预处理函数,默认为空格分词。
bm25_params (Dict): 传递给 BM25Okapi 的参数(如 k1、b 等)。

核心方法:
- from_texts/from_documents: 从原始文本或 Document 构建检索器。
- _get_relevant_documents: 检索与查询最相关的前 k 个文档。
- get_scores: 获取查询对所有文档的 BM25 分数。
- get_top_k_with_scores: 获取 top-k 文档及其分数。
- add_documents/delete_documents: 动态增删文档并重建索引。
- get_bm25_info: 获取检索器配置信息和统计。
- update_k: 动态调整返回文档数量。

性能注意事项:
- 每次添加/删除文档都会重建 BM25 索引,适合文档量较小或更新不频繁的场景。
- 文档量较大或频繁更新时,建议使用向量检索方案。
- 支持异步操作,便于大规模数据处理。

典型用法:
>>> retriever = BM25Retriever.from_texts(["文本1", "文本2"], k=3)
>>> results = retriever._get_relevant_documents("查询语句")
>>> retriever.add_documents([Document(content="新文档")])
>>> retriever.delete_documents(ids=["doc_id"])
>>> info = retriever.get_bm25_info()

Attributes:
vectorizer: BM25 向量化器实例
docs: 文档列表
Expand Down Expand Up @@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Returns:
验证后的值
"""
k = values.get("k", 4)
k = values.get("k", 5)
if k <= 0:
raise ValueError(f"k 必须大于 0,当前值: {k}")

Expand Down Expand Up @@ -259,46 +277,48 @@ def from_documents(
)

def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
    """Retrieve the top-k documents most relevant to ``query`` via BM25.

    Args:
        query: Query string.
        **kwargs: Optional overrides; ``k`` replaces the retriever's default
            number of results for this call.

    Returns:
        The top-k documents ranked by descending BM25 score.

    Raises:
        ValueError: If the BM25 vectorizer has not been initialized.
    """
    if self.vectorizer is None:
        raise ValueError("BM25 向量化器未初始化")

    if not self.docs:
        logger.warning("文档列表为空,返回空结果")
        return []

    # Per-call k override, clamped so we never ask for more docs than exist.
    k = kwargs.get('k', self.k)
    k = min(k, len(self.docs))

    try:
        # Tokenize the query with the same preprocessor used for the corpus.
        processed_query = self.preprocess_func(query)
        logger.debug(f"预处理后的查询: {processed_query}")

        # Score every document, then take the k highest-scoring indices.
        scores = self.vectorizer.get_scores(processed_query)
        top_indices = np.argsort(scores)[::-1][:k]
        top_docs = [self.docs[idx] for idx in top_indices]
        logger.debug(f"找到 {len(top_docs)} 个相关文档")
        return top_docs

    except Exception as e:
        logger.error(f"BM25 搜索时发生错误: {e}")
        raise

def get_scores(self, query: str) -> List[float]:
"""获取查询对所有文档的 BM25 分数

Expand Down
4 changes: 2 additions & 2 deletions rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import logging

from pydantic import ConfigDict, Field, model_validator
from Retrieval.RetrieverBase import BaseRetriever, Document
from Store.VectorStore.VectorStoreBase import VectorStore
from ..RetrieverBase import BaseRetriever, Document
from ...Store.VectorStore.VectorStoreBase import VectorStore

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion rag_factory/Store/VectorStore/VectorStore_Faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import numpy as np
from typing import Any, Optional, Callable
from .VectorStoreBase import VectorStore, Document
from Embed import Embeddings
from ...Embed.Embedding_Base import Embeddings
import asyncio
from concurrent.futures import ThreadPoolExecutor

# TODO 需要支持GPU,提高速度

def _mmr_select(
docs_and_scores: list[tuple[Document, float]],
Expand Down
2 changes: 1 addition & 1 deletion rag_factory/Store/VectorStore/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# VectorStore/registry.py
from typing import Dict, Type, Any, Optional
from .VectorStoreBase import VectorStore
from Embed.Embedding_Base import Embeddings
from ...Embed.Embedding_Base import Embeddings
from .VectorStore_Faiss import FaissVectorStore


Expand Down
27 changes: 27 additions & 0 deletions rag_factory/rerankers/Reranker_Base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from ..Retrieval import Document
import warnings

class RerankerBase(ABC):
    """Abstract base class for rerankers.

    Subclasses must implement :meth:`rerank`. This class cannot be
    instantiated directly: because it derives from ``ABC`` and declares an
    abstract method, ``RerankerBase()`` raises ``TypeError`` before
    ``__init__`` would ever run (the previous in-``__init__`` warning was
    therefore unreachable and has been removed).

    Usage:
        class MyReranker(RerankerBase):
            def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
                # implement the concrete reranking logic
                ...
    """

    @abstractmethod
    def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
        """Rerank ``documents`` by relevance to ``query``.

        Must be implemented by subclasses; calling the base implementation
        (e.g. via ``super().rerank(...)``) always raises.
        """
        raise NotImplementedError("子类必须实现 rerank 方法。")
75 changes: 75 additions & 0 deletions rag_factory/rerankers/Reranker_Qwen3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .Reranker_Base import RerankerBase
from ..Retrieval.RetrieverBase import Document

class Qwen3Reranker(RerankerBase):
    """Reranker backed by a Qwen3 causal-LM reranking checkpoint.

    The model is prompted to judge whether a document satisfies the query
    ("yes"/"no"), and the probability assigned to "yes" is used as the
    relevance score for ranking.
    """

    def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
        """Load the tokenizer and model and precompute the prompt scaffold.

        Args:
            model_name_or_path: Hugging Face model id or local checkpoint path.
            max_length: Maximum total prompt length in tokens.
            instruction: Default retrieval instruction; a built-in English
                instruction is used when None.
            attn_type: Accepted for interface compatibility; not used here.
            device_id: Torch device string, e.g. "cuda:0".
            **kwargs: Ignored extra options.
        """
        super().__init__()
        device = torch.device(device_id)
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
        self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
        self.lm = self.lm.to(device).eval()
        # Token ids from which the yes/no logits are read off.
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        # Fixed chat-template scaffold wrapped around every scored pair.
        self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
        self.instruction = instruction or "Given the user query, retrieval the relevant passages"
        self.device = device

    def format_instruction(self, instruction, query, doc):
        """Render one (instruction, query, document) triple as the user turn."""
        if instruction is None:
            instruction = self.instruction
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

    def process_inputs(self, pairs):
        """Tokenize formatted prompts, add the scaffold, and pad to a batch.

        Truncation budget reserves room for the prefix/suffix tokens so the
        final sequence never exceeds ``max_length``.
        """
        out = self.tokenizer(
            pairs, padding=False, truncation='longest_first',
            return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        )
        for i, ele in enumerate(out['input_ids']):
            out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
        out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in out:
            out[key] = out[key].to(self.lm.device)
        return out

    @torch.no_grad()
    def compute_logits(self, inputs, **kwargs):
        """Return P("yes") per example from the final-position logits."""
        batch_scores = self.lm(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        # Softmax over just the {no, yes} pair; take the "yes" probability.
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        return batch_scores[:, 1].exp().tolist()

    def compute_scores(self, pairs, instruction=None, **kwargs):
        """Score a list of (query, document) pairs; returns a list of floats."""
        pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
        inputs = self.process_inputs(pairs)
        return self.compute_logits(inputs)

    def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: Query string.
            documents: Documents to rank (their ``content`` is scored).
            k: If given, return only the top-k documents.
            batch_size: Number of pairs scored per model forward pass.
            **kwargs: May contain ``instruction`` to override the default
                instruction (previously accepted but silently ignored).

        Returns:
            Documents sorted by descending relevance score.
        """
        pairs = [(query, doc.content) for doc in documents]
        instruction = kwargs.get("instruction")  # None -> model default

        # Score in batches to bound peak memory.
        scores: list = []
        for start in range(0, len(pairs), batch_size):
            scores.extend(self.compute_scores(pairs[start:start + batch_size], instruction=instruction))

        # Stable sort by score, highest first.
        doc_score_pairs = sorted(zip(documents, scores), key=lambda p: p[1], reverse=True)
        reranked_docs = [doc for doc, _ in doc_score_pairs]
        return reranked_docs if k is None else reranked_docs[:k]
4 changes: 4 additions & 0 deletions rag_factory/rerankers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .Reranker_Base import RerankerBase
from .Reranker_Qwen3 import Qwen3Reranker

# Public API of the rerankers package.
__all__ = ["RerankerBase", "Qwen3Reranker"]