
Commit c8b2bd3

add search backend
1 parent fcd3b71 commit c8b2bd3

File tree

6 files changed: +381 -68 lines changed

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from .default_backend_search import DefaultEncoderSearchBackend
from .faiss_search_backend import FaissEncoderSearchBackend
from .search_backend_protocol import IndexEncoderSearchProtocol

__all__ = [
    "DefaultEncoderSearchBackend",
    "FaissEncoderSearchBackend",
    "IndexEncoderSearchProtocol",
]
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import logging
from collections.abc import Callable

import torch

from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class DefaultEncoderSearchBackend:
    """Streaming backend for encoder-based search.

    - Does not store the entire corpus in memory.
    - Encodes and searches the corpus in chunks.
    """

    sub_corpus_embeddings: Array | None = None
    idxs: list[str]

    def add_document(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Add all document embeddings and their IDs to the backend."""
        self.sub_corpus_embeddings = embeddings
        self.idxs = idxs

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search through added corpus embeddings or rerank top-ranked documents."""
        if self.sub_corpus_embeddings is None:
            raise ValueError("No corpus embeddings found. Did you call add_document()?")

        if top_ranked is not None:
            # Reranking mode: score each query only against its candidate documents.
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id is required when using top_ranked.")

            scores_all: list[list[float]] = []
            idxs_all: list[list[int]] = []

            # Map corpus-level document IDs to row positions in the stored embeddings.
            doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}

            for query_idx, query_emb in enumerate(embeddings):
                query_id = query_idx_to_id[query_idx]
                ranked_ids = top_ranked.get(query_id)
                if not ranked_ids:
                    logger.warning(f"No top-ranked docs for query {query_id}")
                    scores_all.append([])
                    idxs_all.append([])
                    continue

                candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
                candidate_embs = self.sub_corpus_embeddings[candidate_idx]

                scores = similarity_fn(
                    torch.as_tensor(query_emb).unsqueeze(0),
                    torch.as_tensor(candidate_embs),
                )

                values, indices = torch.topk(
                    torch.as_tensor(scores),
                    k=min(top_k, len(candidate_idx)),
                    dim=1,
                    largest=True,
                )
                scores_all.append(values.squeeze(0).cpu().tolist())
                idxs_all.append(indices.squeeze(0).cpu().tolist())

            return scores_all, idxs_all

        # Full-corpus mode: score all queries against the stored chunk, then
        # release the chunk so the next one can be streamed in.
        scores = similarity_fn(embeddings, self.sub_corpus_embeddings)
        self.sub_corpus_embeddings = None

        scores_tensor = torch.as_tensor(scores)
        # top_k + 1 keeps one spare result (e.g. to drop a query's self-match
        # downstream); k must not exceed the number of scored documents.
        k = min(top_k + 1, scores_tensor.shape[1])
        cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
            scores_tensor, k, dim=1, largest=True
        )
        return (
            cos_scores_top_k_values.cpu().tolist(),
            cos_scores_top_k_idx.cpu().tolist(),
        )

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.sub_corpus_embeddings = None
        self.idxs = []
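For orientation, a minimal usage sketch of the streaming backend; the toy tensors and the `cos_sim` helper below are illustrative stand-ins, not part of this commit:

import torch

# Toy chunk: 4 document embeddings and 2 query embeddings (illustrative only).
doc_embs = torch.nn.functional.normalize(torch.randn(4, 8), dim=1)
query_embs = torch.nn.functional.normalize(torch.randn(2, 8), dim=1)

def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # With L2-normalized inputs, the inner product equals cosine similarity.
    return torch.as_tensor(a) @ torch.as_tensor(b).T

backend = DefaultEncoderSearchBackend()
backend.add_document(doc_embs, idxs=["d1", "d2", "d3", "d4"])
values, indices = backend.search(query_embs, top_k=2, similarity_fn=cos_sim)
# indices are positions within the added chunk; map back to IDs via backend.idxs.

Note that full-corpus search releases the stored chunk (`sub_corpus_embeddings` is set back to `None`), so each chunk is scored exactly once before the next one is added.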
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import logging
from collections.abc import Callable

import faiss
import numpy as np
import torch
from faiss import IndexFlatIP

from mteb import EncoderProtocol
from mteb.models.model_meta import ScoringFunction
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissEncoderSearchBackend:
    """FAISS-based backend for encoder-based search.

    Supports both full-corpus retrieval and reranking (via `top_ranked`).

    Notes:
        - Stores *all* embeddings in memory (IndexFlatIP).
        - Embeddings are L2-normalized by the backend when cosine similarity is used.
    """

    _normalize: bool = False

    def __init__(self, model: EncoderProtocol) -> None:
        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        if (
            model.mteb_model_meta.similarity_fn_name == "dot"
            or model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT
        ):
            # Dot-product models: raw inner product, no normalization.
            self.index_type = IndexFlatIP
        else:
            # Cosine similarity: inner product over L2-normalized vectors.
            self.index_type = IndexFlatIP
            self._normalize = True

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

    def add_document(self, embeddings: Array, idxs: list[str]) -> None:
        """Add all document embeddings and their IDs to the FAISS index."""
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # FAISS requires float32, C-contiguous input.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        self.idxs.extend(idxs)

        if self._normalize:
            faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        if self.index is None:
            self.index = self.index_type(dim)

        self.index.add(embeddings)
        logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search using FAISS."""
        if self.index is None:
            raise ValueError("No index built. Call add_document() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        # Cast before normalize_L2, which operates in place on float32 arrays.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        if self._normalize:
            faiss.normalize_L2(embeddings)

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id must be provided when reranking.")

            doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
            scores_all: list[list[float]] = []
            idxs_all: list[list[int]] = []

            for query_idx, query_emb in enumerate(embeddings):
                query_id = query_idx_to_id[query_idx]
                ranked_ids = top_ranked.get(query_id)
                if not ranked_ids:
                    logger.warning(f"No top-ranked documents for query {query_id}")
                    scores_all.append([])
                    idxs_all.append([])
                    continue

                candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
                # Flat indexes support exact reconstruction of stored vectors.
                d = self.index.d
                candidate_embs = np.zeros((len(candidate_indices), d), dtype=np.float32)
                for j, idx in enumerate(candidate_indices):
                    candidate_embs[j] = self.index.reconstruct(idx)

                scores = similarity_fn(
                    torch.as_tensor(query_emb).unsqueeze(0),
                    torch.as_tensor(candidate_embs),
                )

                values, indices = torch.topk(
                    torch.as_tensor(scores),
                    k=min(top_k, len(candidate_indices)),
                    dim=1,
                    largest=True,
                )
                scores_all.append(values.squeeze(0).cpu().tolist())
                idxs_all.append(indices.squeeze(0).cpu().tolist())

            return scores_all, idxs_all

        # Full-corpus mode: FAISS returns one row of (scores, indices) per query.
        scores, ids = self.index.search(embeddings, top_k)
        return scores.tolist(), ids.tolist()

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.index = None
        self.idxs = []
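A sketch of full-corpus retrieval through the FAISS backend. The `_DummyMeta`/`_DummyModel` stand-ins below exist only to satisfy the constructor's `mteb_model_meta.similarity_fn_name` lookup and are not part of this commit:

import numpy as np

class _DummyMeta:
    similarity_fn_name = "cosine"  # anything but dot product selects the normalized-IP path

class _DummyModel:
    mteb_model_meta = _DummyMeta()

backend = FaissEncoderSearchBackend(_DummyModel())
backend.add_document(
    np.random.rand(100, 16).astype(np.float32),
    idxs=[f"doc-{i}" for i in range(100)],
)
scores, ids = backend.search(
    np.random.rand(3, 16).astype(np.float32),
    top_k=5,
    similarity_fn=lambda a, b: a @ b.T,  # only consulted in reranking mode
)
# ids hold insertion-order positions; backend.idxs maps them back to document IDs.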
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
    """Protocol for search backends used in encoder-based retrieval."""

    def add_document(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Add documents to the search backend.

        Args:
            embeddings: Embeddings of the documents to add.
            idxs: IDs of the documents to add.
        """

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search through added corpus embeddings or rerank top-ranked documents.

        Supports both full-corpus and reranking search modes:
        - Full-corpus mode: `top_ranked=None`; searches the added corpus embeddings.
        - Reranking mode: `top_ranked` contains a mapping {query_id: [doc_ids]}.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of top results to return.
            similarity_fn: Function to compute similarity between query and corpus embeddings.
            top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
            query_idx_to_id: Mapping of query index -> query_id. Used for reranking.

        Returns:
            A tuple (top_k_values, top_k_indices) where, for each query:
            - top_k_values: List of top-k similarity scores.
            - top_k_indices: List of indices of the top-k documents in the added corpus.
        """

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
