Feat/llmlayerwise #25

Merged: 3 commits, merged Aug 16, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -14,9 +14,10 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an

## Updates

- v0.5.0: Added support for the current state-of-the-art rerankers, BAAI's series of `BGE` layerwise LLM rerankers, based on [Gemma](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) and MiniCPM. These are different from RankGPT, as they're not listwise: the models are repurposed as "cross-encoders", and do output logit scores.
- v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results, following the implementation of the JaColBERTv2.5 dynamic query length method. This version also adds support for HuggingFace's Text-Embeddings-Inference (TEI) as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
- v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
- v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated on directly...)
- v0.3.0: Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated on directly...)
- v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
- v0.1.2: Voyage reranking API
- v0.1.1: Langchain integration fixed!
@@ -198,6 +199,7 @@ Models:
- ✅ Any standard SentenceTransformer or Transformers cross-encoder
- ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
- ✅ T5-based pointwise rankers (InRanker, MonoT5...)
- ✅ LLM-based pointwise rankers (BAAI/bge-reranker-v2.5-gemma2-lightweight, etc...)
- ✅ Cohere, Jina, Voyage and MixedBread API rerankers
- ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
- ✅ ColBERT-based reranker - not a model initially designed for reranking, but does perform quite strongly in some cases. Implementation is lightweight, based only on transformers.
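To illustrate the README addition above, here is a minimal sketch of loading one of the new layerwise rerankers through the unified `Reranker` API. The `llm-layerwise` model type and the default model name come from this PR's diff; the `rank()` call mirrors the signature added in `llm_layerwise_ranker.py`, and the concrete query and documents are placeholders. Installing with the new `llmlayerwise` extra is assumed to pull in the required `transformers`/`torch` dependencies.

```python
from rerankers import Reranker

# "llm-layerwise" is the model type registered by this PR; passing the full
# HuggingFace model name also works thanks to the name-based detection below.
ranker = Reranker(
    "BAAI/bge-reranker-v2.5-gemma2-lightweight",
    model_type="llm-layerwise",
)

results = ranker.rank(
    query="What is the capital of France?",
    docs=["Paris is the capital of France.", "Berlin is in Germany."],
    doc_ids=[0, 1],
)
print(results)
```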
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ packages = [
name = "rerankers"


version = "0.4.0"
version = "0.5.0"

description = "A unified API for various document re-ranking models."

@@ -60,13 +60,15 @@ all = [
"sentencepiece",
"protobuf",
"flashrank",
"flash-attn",
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
]
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
api = ["requests"]
gpt = ["litellm"]
flashrank = ["flashrank"]
llmlayerwise = ["transformers", "torch", "sentencepiece", "protobuf", "flash-attn"]
rankllm = [
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
2 changes: 1 addition & 1 deletion rerankers/__init__.py
@@ -2,4 +2,4 @@
from rerankers.documents import Document

__all__ = ["Reranker", "Document"]
__version__ = "0.4.0"
__version__ = "0.5.0"
7 changes: 7 additions & 0 deletions rerankers/models/__init__.py
@@ -45,3 +45,10 @@
AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
except ImportError:
pass

try:
from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker

AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
except ImportError:
pass
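The registration above follows the library's optional-dependency pattern: the ranker is only added to `AVAILABLE_RANKERS` when its imports succeed. Below is a small, hypothetical sketch of probing that registry before constructing the ranker directly; the fallback message is illustrative and not part of the library.

```python
from rerankers.models import AVAILABLE_RANKERS

if "LLMLayerWiseRanker" in AVAILABLE_RANKERS:
    # The optional dependencies (transformers, torch, ...) imported fine.
    ranker_cls = AVAILABLE_RANKERS["LLMLayerWiseRanker"]
    ranker = ranker_cls("BAAI/bge-reranker-v2.5-gemma2-lightweight")
else:
    print("LLMLayerWiseRanker unavailable: install rerankers[llmlayerwise]")
```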
198 changes: 198 additions & 0 deletions rerankers/models/llm_layerwise_ranker.py
@@ -0,0 +1,198 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rerankers.models.ranker import BaseRanker
from rerankers.documents import Document
from typing import Union, List, Optional
from rerankers.utils import vprint, get_device, get_dtype, prep_docs
from rerankers.results import RankedResults, Result


PROMPTS = {
    "BAAI/bge-reranker-v2.5-gemma2-lightweight": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
    "default": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
}

DEFAULT_PARAMS = {
    "default": {},
    "BAAI/bge-multilingual-gemma2": {},
    "BAAI/bge-reranker-v2-gemma": {},
    "BAAI/bge-reranker-v2-minicpm-layerwise": {"cutoff_layers": [28]},
    "BAAI/bge-reranker-v2.5-gemma2-lightweight": {
        "cutoff_layers": [28],
        "compress_ratio": 2,
        "compress_layer": [24, 40],
    },
}


class LLMLayerWiseRanker(BaseRanker):
    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-v2.5-gemma2-lightweight",
        max_sequence_length: int = 512,
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 16,
        verbose: int = 1,
        prompt: Optional[str] = None,
        cutoff_layers: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        compress_layer: Optional[List[int]] = None,
    ):
        self.verbose = verbose
        self.device = get_device(device, verbose=self.verbose)
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        self.batch_size = batch_size

        vprint(
            f"Loading model {model_name_or_path}, this might take a while...",
            self.verbose,
        )
        vprint(f"Using device {self.device}.", self.verbose)
        vprint(f"Using dtype {self.dtype}.", self.verbose)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.max_sequence_length = max_sequence_length
        self.tokenizer.model_max_length = self.max_sequence_length
        self.tokenizer.padding_side = "right"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype
        ).to(self.device)
        self.model.eval()

        # Create params dict based on specified values or defaults
        params = {}
        if cutoff_layers is not None:
            params["cutoff_layers"] = cutoff_layers
        if compress_ratio is not None:
            params["compress_ratio"] = compress_ratio
        if compress_layer is not None:
            params["compress_layer"] = compress_layer
        if not params:
            params = DEFAULT_PARAMS.get(model_name_or_path, DEFAULT_PARAMS["default"])
        self.params = params

        self.prompt = prompt
        if self.prompt is None:
            self.prompt = PROMPTS.get(model_name_or_path, PROMPTS["default"])

    def _get_inputs(self, pairs, max_sequence_length: int):
        # Each (query, passage) pair is encoded as:
        #   <bos> A: {query} \n B: {passage} \n {prompt}
        # so the model predicts relevance at the final prompt token.
        prompt = self.prompt
        sep = "\n"
        prompt_inputs = self.tokenizer(
            prompt, return_tensors=None, add_special_tokens=False
        )["input_ids"]
        sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[
            "input_ids"
        ]
        inputs = []
        for query, passage in pairs:
            query_inputs = self.tokenizer(
                f"A: {query}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_sequence_length * 3 // 4,
                truncation=True,
            )
            passage_inputs = self.tokenizer(
                f"B: {passage}",
                return_tensors=None,
                add_special_tokens=False,
                max_length=max_sequence_length,
                truncation=True,
            )
            item = self.tokenizer.prepare_for_model(
                [self.tokenizer.bos_token_id] + query_inputs["input_ids"],
                sep_inputs + passage_inputs["input_ids"],
                truncation="only_second",
                max_length=max_sequence_length,
                padding=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )
            item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
            item["attention_mask"] = [1] * len(item["input_ids"])
            inputs.append(item)

        return self.tokenizer.pad(
            inputs,
            padding=True,
            max_length=max_sequence_length + len(sep_inputs) + len(prompt_inputs),
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    @torch.no_grad()
    def rank(
        self,
        query: str,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
        batch_size: Optional[int] = None,
        max_sequence_length: Optional[int] = None,
    ) -> RankedResults:
        docs = prep_docs(docs, doc_ids, metadata)
        pairs = [(query, doc.text) for doc in docs]

        # Override self.batch_size if explicitly set
        if batch_size is None:
            batch_size = self.batch_size

        # Same for max_sequence_length
        if max_sequence_length is None:
            max_sequence_length = self.max_sequence_length

        batched_pairs = [
            pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)
        ]
        scores = []

        for batch in batched_pairs:
            inputs = self._get_inputs(batch, max_sequence_length=max_sequence_length)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            outputs = self.model(**inputs, return_dict=True, **self.params)
            # outputs[0] holds one score tensor per requested cutoff layer;
            # the last token's logit from the final requested layer is used
            # as the relevance score.
            all_scores = [
                scores[:, -1]
                .view(
                    -1,
                )
                .float()
                for scores in outputs[0]
            ]
            batch_scores = all_scores[-1].cpu().numpy().tolist()

            scores.extend(batch_scores)

        ranked_results = [
            Result(document=doc, score=score, rank=idx + 1)
            for idx, (doc, score) in enumerate(
                sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
            )
        ]
        return RankedResults(results=ranked_results, query=query, has_scores=True)

    @torch.no_grad()
    def score(self, query: str, doc: str) -> float:
        inputs = self._get_inputs(
            [(query, doc)], max_sequence_length=self.max_sequence_length
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        outputs = self.model(**inputs, return_dict=True, **self.params)
        all_scores = [
            scores[:, -1]
            .view(
                -1,
            )
            .float()
            for scores in outputs[0]
        ]
        score = all_scores[-1].item()

        return score
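A short sketch of using `LLMLayerWiseRanker` directly. The constructor arguments and the `rank()`/`score()` signatures are taken from the file above, while the query and passages are placeholders; with no explicit overrides, the per-model `DEFAULT_PARAMS` (cutoff layers and compression settings) are applied.

```python
from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker

# No cutoff_layers/compress_* overrides passed, so DEFAULT_PARAMS are used.
ranker = LLMLayerWiseRanker("BAAI/bge-reranker-v2.5-gemma2-lightweight", verbose=0)

# Score a single (query, passage) pair...
single_score = ranker.score(
    "what is a panda?",
    "The giant panda is a bear species endemic to China.",
)

# ...or rank a small set of passages for one query.
results = ranker.rank(
    query="what is a panda?",
    docs=[
        "The giant panda is a bear species endemic to China.",
        "Pandas is a Python library for data analysis.",
    ],
    doc_ids=["wiki", "pypi"],
)
print(single_score, results)
```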
10 changes: 8 additions & 2 deletions rerankers/reranker.py
@@ -30,7 +30,11 @@
"es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
},
"flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"},
"llm-layerwise": {
"en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
"other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
},
}

DEPS_MAPPING = {
@@ -42,6 +46,7 @@
"ColBERTRanker": "transformers",
"FlashRankRanker": "flashrank",
"RankLLMRanker": "rankllm",
"LLMLayerWiseRanker": "transformers",
}

PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
@@ -78,6 +83,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"cross-encoder": "TransformerRanker",
"flashrank": "FlashRankRanker",
"rankllm": "RankLLMRanker",
"llm-layerwise": "LLMLayerWiseRanker",
}
return model_mapping.get(explicit_model_type, explicit_model_type)
else:
@@ -89,7 +95,6 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"rankllm": "RankLLMRanker",
"rankgpt": "RankGPTRanker",
"gpt": "RankGPTRanker",
"zephyr": "RankZephyr",
"colbert": "ColBERTRanker",
"cohere": "APIRanker",
"jina": "APIRanker",
@@ -99,6 +104,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"ms-marco-multibert-l-12": "FlashRankRanker",
"vicuna": "RankLLMRanker",
"zephyr": "RankLLMRanker",
"bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
}
for key, value in model_mapping.items():
if key in model_name:
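Finally, a hedged sketch of what the `reranker.py` changes enable. Because the substring `bge-reranker-v2.5-gemma2-lightweight` is now in the detection table, passing the full model name without an explicit `model_type` should resolve to `LLMLayerWiseRanker`; other layerwise models, such as the MiniCPM variant listed in `DEFAULT_PARAMS`, have no substring entry and would need the explicit `llm-layerwise` type. The factory behaviour itself is assumed to match how the other model types are handled.

```python
from rerankers import Reranker

# Name-based detection: the new substring entry routes this model
# to LLMLayerWiseRanker without an explicit model_type.
auto = Reranker("BAAI/bge-reranker-v2.5-gemma2-lightweight")

# No substring entry exists for the MiniCPM layerwise model,
# so the model type must be passed explicitly.
minicpm = Reranker(
    "BAAI/bge-reranker-v2-minicpm-layerwise", model_type="llm-layerwise"
)
```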