Feature: integrate sentence transformers into retriever module
ignorejjj committed Jun 11, 2024
1 parent 917be8c commit 12e330e
Showing 7 changed files with 141 additions and 29 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -52,14 +52,16 @@ FlashRAG is still under development and there are many issues and room for improvement
- [x] Support OpenAI models
- [ ] Support Claude and Gemini models
- [x] Provide instructions for each component
- [ ] Integrate sentence Transformers
- [x] Integrate sentence Transformers
- [ ] Include more RAG approaches
- [ ] Add more evaluation metrics (e.g., Unieval, name-entity F1) and benchmarks (e.g., RGB benchmark)
- [ ] Enhance code adaptability and readability


## :page_with_curl: Changelog

[24/06/11] We have integrated `sentence-transformers` into the retriever module, so more retrievers can now be used.

[24/06/05] We have provided detailed documentation for reproducing existing methods (see [how to reproduce](./docs/reproduce_experiment.md) and [baseline details](./docs/baseline_details.md)) and [<u>configuration settings</u>](./docs/configuration.md)

[24/06/02] We have provided an introduction to FlashRAG for beginners; see [<u>an introduction to FlashRAG</u>](./docs/instruction_for_beginners_en.md) ([<u>中文版</u>](./docs/introduction_for_beginners_zh.md))
20 changes: 20 additions & 0 deletions docs/building-index.md
@@ -41,6 +41,26 @@ python -m flashrag.retriever.index_builder \

* ```--pooling_method```: If this is not specified, we will automatically select it based on the model name. However, because different embedding models use different pooling methods, **we may not have fully implemented all of them**. To ensure accuracy, you can **specify the pooling method corresponding to the retrieval model** you are using (`mean`, `pooler`, or `cls`).


If the retrieval model supports the `sentence-transformers` library, you can use the following command to build the index (**no need to consider the pooling method**).

```bash
python -m flashrag.retriever.index_builder \
--retrieval_method e5 \
--model_path /model/e5-base-v2/ \
--corpus_path indexes/sample_corpus.jsonl \
--save_dir indexes/ \
--use_fp16 \
--max_length 200 \
--batch_size 32 \
--pooling_method mean \
--sentence_transformer \
--faiss_type Flat
```
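
Under the hood, the `--sentence_transformer` flag makes the index builder wrap the model in the new `STEncoder`, which takes the embedding size and pooling directly from the model's own `sentence-transformers` configuration. Below is a minimal sketch of that path; the model path and document text are placeholders.

```python
# Minimal sketch of the sentence-transformers path used by the index builder.
# The model path and document text are placeholders.
from flashrag.retriever.encoder import STEncoder

encoder = STEncoder(
    model_name="e5",
    model_path="/model/e5-base-v2/",  # placeholder path
    max_length=200,
    use_fp16=True,
)

# The embedding dimension is read from the sentence-transformers model itself,
# which is why no pooling method needs to be specified here.
hidden_size = encoder.model.get_sentence_embedding_dimension()

# Documents get the model-specific prefix (e.g. "passage: " for e5) added
# automatically and come back as normalized float32 numpy embeddings.
doc_emb = encoder.encode(["FlashRAG is a toolkit for RAG research."], is_query=False)
print(doc_emb.shape, hidden_size)
```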




#### For sparse retrieval method (BM25)

If building a BM25 index, there is no need to specify `model_path`:
1 change: 1 addition & 0 deletions flashrag/config/basic_config.yaml
@@ -51,6 +51,7 @@ index_path: ~ # set automatically if not provided.
faiss_gpu: False # whether use gpu to hold index
corpus_path: ~ # path to corpus in '.jsonl' format that store the documents

use_sentence_transformer: False # whether to load the retrieval model via sentence-transformers
retrieval_topk: 5 # number of retrieved documents
retrieval_batch_size: 256 # batch size for retrieval
retrieval_use_fp16: True # whether to use fp16 for retrieval model
79 changes: 64 additions & 15 deletions flashrag/retriever/encoder.py
@@ -4,6 +4,27 @@
from flashrag.retriever.utils import load_model, pooling


def parse_query(model_name, query_list, is_query=True):
"""
    Add model-specific prefixes to queries or passages for different encoders.
"""

if isinstance(query_list, str):
query_list = [query_list]

if "e5" in model_name.lower():
if is_query:
query_list = [f"query: {query}" for query in query_list]
else:
query_list = [f"passage: {query}" for query in query_list]

if "bge" in model_name.lower():
if is_query:
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]

return query_list


class Encoder:
def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
self.model_name = model_name
@@ -17,19 +38,7 @@ def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16)

@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]

if "e5" in self.model_name.lower():
if is_query:
query_list = [f"query: {query}" for query in query_list]
else:
query_list = [f"passage: {query}" for query in query_list]

if "bge" in self.model_name.lower():
if is_query:
query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
query_list = parse_query(self.model_name, query_list, is_query)

inputs = self.tokenizer(query_list,
max_length=self.max_length,
@@ -55,9 +64,49 @@ def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
output.last_hidden_state,
inputs['attention_mask'],
self.pooling_method)
if "dpr" not in self.model_name.lower():
query_emb = torch.nn.functional.normalize(query_emb, dim=-1)

query_emb = query_emb.detach().cpu().numpy()
query_emb = query_emb.astype(np.float32, order="C")
return query_emb

class STEncoder:
def __init__(self, model_name, model_path, max_length, use_fp16):
import torch
from sentence_transformers import SentenceTransformer

self.model_name = model_name
self.model_path = model_path
self.max_length = max_length
self.use_fp16 = use_fp16

self.model = SentenceTransformer(model_path, model_kwargs = {"torch_dtype": torch.float16 if use_fp16 else torch.float})


@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
query_list = parse_query(self.model_name, query_list, is_query)
query_emb = self.model.encode(
query_list,
batch_size = len(query_list),
convert_to_numpy = True,
normalize_embeddings = True
)
query_emb = query_emb.astype(np.float32, order="C")

return query_emb

@torch.no_grad()
def multi_gpu_encode(self, query_list: List[str], is_query=True, batch_size=None) -> np.ndarray:
query_list = parse_query(self.model_name, query_list, is_query)
pool = self.model.start_multi_process_pool()
query_emb = self.model.encode_multi_process(
query_list, pool,
convert_to_numpy = True,
normalize_embeddings = True,
batch_size = batch_size
)
self.model.stop_multi_process_pool(pool)
        query_emb = query_emb.astype(np.float32, order="C")

return query_emb
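
As a usage sketch for the class above (not part of this commit): the model path and queries below are placeholders, and `multi_gpu_encode` distributes the batch over sentence-transformers' multi-process pool.

```python
# Usage sketch for STEncoder; model path and queries are placeholders.
from flashrag.retriever.encoder import STEncoder

encoder = STEncoder(
    model_name="e5",
    model_path="/model/e5-base-v2/",  # placeholder path
    max_length=200,
    use_fp16=True,
)

queries = ["who proposed the theory of evolution?", "what is dense retrieval?"]

# Single-process encoding: e5 queries get the "query: " prefix and come back
# as L2-normalized float32 embeddings of shape (num_queries, hidden_size).
emb = encoder.encode(queries, is_query=True)

# Multi-process encoding spreads the work across all visible GPUs.
emb_mp = encoder.multi_gpu_encode(queries, is_query=True, batch_size=32)
print(emb.shape, emb_mp.shape)
```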

42 changes: 36 additions & 6 deletions flashrag/retriever/index_builder.py
@@ -29,7 +29,8 @@ def __init__(
faiss_type=None,
embedding_path=None,
save_embedding=False,
faiss_gpu=False
faiss_gpu=False,
use_sentence_transformer=False
):

self.retrieval_method = retrieval_method.lower()
@@ -44,6 +45,7 @@ def __init__(
self.embedding_path = embedding_path
self.save_embedding = save_embedding
self.faiss_gpu = faiss_gpu
self.use_sentence_transformer = use_sentence_transformer

self.gpu_num = torch.cuda.device_count()
# prepare save dir
@@ -141,6 +143,21 @@ def _save_embedding(self, all_embeddings):
else:
memmap[:] = all_embeddings

def st_encode_all(self):
if self.gpu_num > 1:
print("Use multi gpu!")
self.batch_size = self.batch_size * self.gpu_num

sentence_list = [item['contents'] for item in self.corpus]
if self.retrieval_method == "e5":
sentence_list = [f"passage: {doc}" for doc in sentence_list]
all_embeddings = self.encode(
sentence_list,
batch_size = self.batch_size
)

return all_embeddings

def encode_all(self):
if self.gpu_num > 1:
print("Use multi gpu!")
Expand Down Expand Up @@ -204,14 +221,25 @@ def build_dense_index(self):
if os.path.exists(self.index_save_path):
print("The index file already exists and will be overwritten.")

self.encoder, self.tokenizer = load_model(model_path = self.model_path,
use_fp16 = self.use_fp16)
if self.embedding_path is not None:
if self.use_sentence_transformer:
from flashrag.retriever.encoder import STEncoder
self.encoder = STEncoder(
model_name = self.retrieval_method,
model_path = self.model_path,
max_length = self.max_length,
use_fp16 = self.use_fp16
)
hidden_size = self.encoder.model.get_sentence_embedding_dimension()
else:
self.encoder, self.tokenizer = load_model(model_path = self.model_path,
use_fp16 = self.use_fp16)
hidden_size = self.encoder.config.hidden_size

if self.embedding_path is not None:
corpus_size = len(self.corpus)
all_embeddings = self._load_embedding(self.embedding_path, corpus_size, hidden_size)
else:
all_embeddings = self.encode_all()
all_embeddings = self.st_encode_all() if self.use_sentence_transformer else self.encode_all()
if self.save_embedding:
self._save_embedding(all_embeddings)
del self.corpus
Expand Down Expand Up @@ -265,6 +293,7 @@ def main():
parser.add_argument('--embedding_path', default=None, type=str)
parser.add_argument('--save_embedding', action='store_true', default=False)
parser.add_argument('--faiss_gpu', default=False, action='store_true')
parser.add_argument('--sentence_transformer', action='store_true', default=False)

args = parser.parse_args()

Expand Down Expand Up @@ -293,7 +322,8 @@ def main():
faiss_type = args.faiss_type,
embedding_path = args.embedding_path,
save_embedding = args.save_embedding,
faiss_gpu = args.faiss_gpu
faiss_gpu = args.faiss_gpu,
use_sentence_transformer = args.sentence_transformer
)
index_builder.build_index()

23 changes: 16 additions & 7 deletions flashrag/retriever/retriever.py
@@ -8,7 +8,7 @@

from flashrag.utils import get_reranker
from flashrag.retriever.utils import load_corpus, load_docs
from flashrag.retriever.encoder import Encoder
from flashrag.retriever.encoder import Encoder, STEncoder


def cache_manager(func):
@@ -235,12 +235,21 @@ def __init__(self, config: dict):
self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)

self.corpus = load_corpus(self.corpus_path)
self.encoder = Encoder(
model_name = self.retrieval_method,
model_path = config['retrieval_model_path'],
pooling_method = config['retrieval_pooling_method'],
max_length = config['retrieval_query_max_length'],
use_fp16 = config['retrieval_use_fp16']

if config['use_sentence_transformer']:
self.encoder = STEncoder(
model_name = self.retrieval_method,
model_path = config['retrieval_model_path'],
max_length = config['retrieval_query_max_length'],
use_fp16 = config['retrieval_use_fp16']
)
else:
self.encoder = Encoder(
model_name = self.retrieval_method,
model_path = config['retrieval_model_path'],
pooling_method = config['retrieval_pooling_method'],
max_length = config['retrieval_query_max_length'],
use_fp16 = config['retrieval_use_fp16']
)
self.topk = config['retrieval_topk']
self.batch_size = self.config['retrieval_batch_size']
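
In practice, switching an existing pipeline to the new backend only requires flipping the new config key (and building the index with `--sentence_transformer`). A sketch of the relevant overrides, with placeholder values and key names taken from `basic_config.yaml` and the branch above, to be merged into a full FlashRAG config:

```python
# Hypothetical overrides for using the sentence-transformers backend at retrieval
# time; values are placeholders and get merged into a complete FlashRAG config.
st_overrides = {
    "use_sentence_transformer": True,              # selects STEncoder instead of Encoder
    "retrieval_method": "e5",
    "retrieval_model_path": "/model/e5-base-v2/",  # placeholder path
    "retrieval_query_max_length": 128,             # placeholder value
    "retrieval_use_fp16": True,
    # retrieval_pooling_method is not used by STEncoder; pooling comes from the model.
}
```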
1 change: 1 addition & 0 deletions requirements.txt
@@ -17,3 +17,4 @@ tqdm
transformers>=4.40.0
vllm>=0.4.1
voyageai
sentence-transformers
