# feat: add search encoder backend #3492
Samoed wants to merge 10 commits into `main` from `search_barckend` (+486 −85). Status: Open.
## Commits (10)
- `c8b2bd3` add search backend (Samoed)
- `8a3527f` make faiss optional (Samoed)
- `b2c3f60` fix import (Samoed)
- `51111ca` use faiss in reranking (Samoed)
- `ae31d1b` add support for multiple similarities (Samoed)
- `2ce10fd` remove check (Samoed)
- `74458c5` update index check (Samoed)
- `05b0ba8` rename and move files (Samoed)
- `48143c0` add missing files (Samoed)
- `7fbc60f` fix import (Samoed)
## Files changed

### `mteb/models/search_encoder_index/__init__.py` (+8)

```python
from .search_backend_protocol import IndexEncoderSearchProtocol
from .search_indexes import FaissSearchIndex, StreamingSearchIndex

__all__ = [
    "FaissSearchIndex",
    "IndexEncoderSearchProtocol",
    "StreamingSearchIndex",
]
```
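These re-exports define the package's public surface; a quick sketch of the resulting imports (the package path is inferred from the file paths later in this diff):

```python
# Import surface provided by this __init__.py (package path inferred from the diff).
from mteb.models.search_encoder_index import (
    FaissSearchIndex,
    IndexEncoderSearchProtocol,
    StreamingSearchIndex,
)

# The protocol is structural (typing.Protocol), so any object with matching
# add_document/search/clear methods conforms; no subclassing is required.
backend: IndexEncoderSearchProtocol = StreamingSearchIndex()
```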
### `mteb/models/search_encoder_index/search_backend_protocol.py` (+50)
```python
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
    """Protocol for search backends used in encoder-based retrieval."""

    def add_document(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Add documents to the search backend.

        Args:
            embeddings: Embeddings of the documents to add.
            idxs: IDs of the documents to add.
        """

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search through added corpus embeddings or rerank top-ranked documents.

        Supports both full-corpus and reranking search modes:
        - Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
        - Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.

        Args:
            embeddings: Query embeddings, shape (num_queries, dim).
            top_k: Number of top results to return.
            similarity_fn: Function to compute similarity between query and corpus.
            top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
            query_idx_to_id: Mapping of query index -> query_id. Used for reranking.

        Returns:
            A tuple (top_k_values, top_k_indices), for each query:
            - top_k_values: List of top-k similarity scores.
            - top_k_indices: List of indices of the top-k documents in the added corpus.
        """

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
```
### `mteb/models/search_encoder_index/search_indexes/__init__.py` (+7)

```python
from .faiss_search_index import FaissSearchIndex
from .streaming_search_index import StreamingSearchIndex

__all__ = [
    "FaissSearchIndex",
    "StreamingSearchIndex",
]
```
### `mteb/models/search_encoder_index/search_indexes/faiss_search_index.py` (+157)
```python
import logging
from collections.abc import Callable

import numpy as np
import torch

from mteb._requires_package import requires_package
from mteb.models.model_meta import ScoringFunction
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissSearchIndex:
    """FAISS-based backend for encoder-based search.

    Supports both full-corpus retrieval and reranking (via `top_ranked`).

    Notes:
        - Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
        - Expects embeddings to be normalized if cosine similarity is desired.
    """

    _normalize: bool = False

    def __init__(self, model: EncoderProtocol) -> None:
        requires_package(
            self,
            "faiss",
            "FAISS-based search",
            install_instruction="pip install mteb[faiss-cpu]",
        )

        import faiss
        from faiss import IndexFlatIP, IndexFlatL2

        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
        if model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT:
            self.index_type = IndexFlatIP
        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE:
            self.index_type = IndexFlatIP
            self._normalize = True
        elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.EUCLIDEAN:
            self.index_type = IndexFlatL2
        else:
            raise ValueError(
                f"FAISS backend does not support similarity function {model.mteb_model_meta.similarity_fn_name}. "
                f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, {ScoringFunction.EUCLIDEAN}."
            )

        self.idxs: list[str] = []
        self.index: faiss.Index | None = None

    def add_document(self, embeddings: Array, idxs: list[str]) -> None:
        """Add all document embeddings and their IDs to the FAISS index."""
        import faiss

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        embeddings = embeddings.astype(np.float32)
        self.idxs.extend(idxs)

        if self._normalize:
            faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        if self.index is None:
            self.index = self.index_type(dim)

        self.index.add(embeddings)
        logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search using FAISS."""
        import faiss

        if self.index is None:
            raise ValueError("No index built. Call add_document() first.")

        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()

        if self._normalize:
            faiss.normalize_L2(embeddings)

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id must be provided when reranking.")

            similarities, ids = self._reranking(
                embeddings,
                top_k,
                top_ranked=top_ranked,
                query_idx_to_id=query_idx_to_id,
            )
        else:
            similarities, ids = self.index.search(embeddings.astype(np.float32), top_k)
            similarities = similarities.tolist()
            ids = ids.tolist()

        if issubclass(self.index_type, faiss.IndexFlatL2):
            similarities = -np.sqrt(np.maximum(similarities, 0))

        return similarities, ids

    def _reranking(
        self,
        embeddings: Array,
        top_k: int,
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
        scores_all: list[list[float]] = []
        idxs_all: list[list[int]] = []

        for query_idx, query_emb in enumerate(embeddings):
            query_id = query_idx_to_id[query_idx]
            ranked_ids = top_ranked.get(query_id)
            if not ranked_ids:
                logger.warning(f"No top-ranked documents for query {query_id}")
                scores_all.append([])
                idxs_all.append([])
                continue

            candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
            d = self.index.d
            candidate_embs = np.vstack(
                [self.index.reconstruct(idx) for idx in candidate_indices]
            )
            sub_reranking_index = self.index_type(d)
            sub_reranking_index.add(candidate_embs)

            # Search returns scores and indices in one call
            scores, local_indices = sub_reranking_index.search(
                query_emb.reshape(1, -1).astype(np.float32),
                min(top_k, len(candidate_indices)),
            )
            # faiss will output 2d arrays even for a single query
            scores_all.append(scores[0].tolist())
            idxs_all.append(local_indices[0].tolist())

        return scores_all, idxs_all

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.index = None
        self.idxs = []
```
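A hedged usage sketch, not from the PR: the encoder is replaced by a `SimpleNamespace` stand-in exposing only the one metadata attribute the constructor reads, and the embeddings are random.

```python
# Hypothetical usage sketch, not from the PR. SimpleNamespace stands in for
# a real mteb encoder; only mteb_model_meta.similarity_fn_name is read here.
from types import SimpleNamespace

import numpy as np

from mteb.models.model_meta import ScoringFunction

model_stub = SimpleNamespace(
    mteb_model_meta=SimpleNamespace(similarity_fn_name=ScoringFunction.COSINE)
)

index = FaissSearchIndex(model_stub)  # duck-typed; EncoderProtocol is not enforced
docs = np.random.rand(100, 384).astype(np.float32)
index.add_document(docs, idxs=[f"doc-{i}" for i in range(100)])

queries = np.random.rand(2, 384).astype(np.float32)
scores, ids = index.search(
    queries,
    top_k=5,
    similarity_fn=lambda q, c: q @ c.T,  # unused on the FAISS path
)
# scores: 2 lists of 5 cosine similarities; ids: row positions in the added corpus
index.clear()
```

In the cosine branch the stored vectors are L2-normalized into an `IndexFlatIP`, so the inner-product scores returned here are cosine similarities.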
### `mteb/models/search_encoder_index/search_indexes/streaming_search_index.py` (+99)
```python
import logging
from collections.abc import Callable

import torch

from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class StreamingSearchIndex:
    """Streaming backend for encoder-based search.

    - Does not store the entire corpus in memory.
    - Encodes and searches corpus in chunks.
    """

    sub_corpus_embeddings: Array | None = None
    idxs: list[str]

    def add_document(
        self,
        embeddings: Array,
        idxs: list[str],
    ) -> None:
        """Add all document embeddings and their IDs to the backend."""
        self.sub_corpus_embeddings = embeddings
        self.idxs = idxs

    def search(
        self,
        embeddings: Array,
        top_k: int,
        similarity_fn: Callable[[Array, Array], Array],
        top_ranked: TopRankedDocumentsType | None = None,
        query_idx_to_id: dict[int, str] | None = None,
    ) -> tuple[list[list[float]], list[list[int]]]:
        """Search through added corpus embeddings or rerank top-ranked documents."""
        if self.sub_corpus_embeddings is None:
            raise ValueError("No corpus embeddings found. Did you call add_document()?")

        if top_ranked is not None:
            if query_idx_to_id is None:
                raise ValueError("query_idx_to_id is required when using top_ranked.")

            scores_all: list[list[float]] = []
            idxs_all: list[list[int]] = []

            doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}

            for query_idx, query_emb in enumerate(embeddings):
                query_id = query_idx_to_id[query_idx]
                ranked_ids = top_ranked.get(query_id)
                if not ranked_ids:
                    logger.warning(f"No top-ranked docs for query {query_id}")
                    scores_all.append([])
                    idxs_all.append([])
                    continue

                candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
                candidate_embs = self.sub_corpus_embeddings[candidate_idx]

                scores = similarity_fn(
                    torch.as_tensor(query_emb).unsqueeze(0),
                    torch.as_tensor(candidate_embs),
                )

                values, indices = torch.topk(
                    torch.as_tensor(scores),
                    k=min(top_k, len(candidate_idx)),
                    dim=1,
                    largest=True,
                )
                scores_all.append(values.squeeze(0).cpu().tolist())
                idxs_all.append(indices.squeeze(0).cpu().tolist())

            return scores_all, idxs_all

        scores = similarity_fn(embeddings, self.sub_corpus_embeddings)
        self.sub_corpus_embeddings = None

        cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
            torch.tensor(scores),
            min(
                top_k + 1,
                len(scores[1]) if len(scores) > 1 else len(scores[-1]),
            ),
            dim=1,
            largest=True,
        )
        return (
            cos_scores_top_k_values.cpu().tolist(),
            cos_scores_top_k_idx.cpu().tolist(),
        )

    def clear(self) -> None:
        """Clear all stored documents and embeddings from the backend."""
        self.sub_corpus_embeddings = None
        self.idxs = []
```
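Note that the full-corpus path consumes the chunk: `sub_corpus_embeddings` is reset to `None` after scoring, so each chunk must be re-added before the next `search` call, and up to `top_k + 1` results are returned per query. A sketch of the implied add-search-discard loop (names and chunk sizes are assumptions, not from the PR):

```python
# Hypothetical chunked flow, not from the PR: each chunk is added, searched,
# and discarded; merging per-chunk hits into a global top-k happens outside.
import torch
import torch.nn.functional as F


def cos_sim(a, b):  # minimal cosine similarity, stands in for mteb's own fn
    a = F.normalize(torch.as_tensor(a, dtype=torch.float32), dim=-1)
    b = F.normalize(torch.as_tensor(b, dtype=torch.float32), dim=-1)
    return a @ b.T


index = StreamingSearchIndex()
queries = torch.rand(4, 16)

for chunk_no in range(3):  # three toy corpus chunks
    chunk_embs = torch.rand(8, 16)
    chunk_ids = [f"doc-{chunk_no}-{i}" for i in range(8)]
    index.add_document(chunk_embs, chunk_ids)
    values, local_idx = index.search(queries, top_k=3, similarity_fn=cos_sim)
    # local_idx indexes into the current chunk; map through chunk_ids and
    # merge with earlier chunks' hits to maintain a running global top-k.
```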
## Review comment

I think we can specify that this is only for the index and only for the encoder, because it can be confused with something `SentenceTransformerEncoderWrapper` will implement (probably).