26 changes: 8 additions & 18 deletions mteb/models/model_implementations/random_baseline.py
@@ -8,6 +8,10 @@

from mteb.abstasks.task_metadata import TaskMetadata
from mteb.models.model_meta import ModelMeta
+from mteb.similarity_functions import (
+    select_pairwise_similarity,
+    select_similarity,
+)
from mteb.types._encoder_io import Array, BatchedInput, PromptType


@@ -155,15 +159,9 @@ def similarity(
Returns:
Cosine similarity matrix between the two sets of embeddings
"""
-        norm1 = np.linalg.norm(
-            embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        norm2 = np.linalg.norm(
-            embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        normalized1 = embeddings1 / (norm1 + 1e-10)
-        normalized2 = embeddings2 / (norm2 + 1e-10)
-        return np.dot(normalized1, normalized2.T)
+        return select_similarity(
+            embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
+        )

def similarity_pairwise(
self,
@@ -179,17 +177,9 @@ def similarity_pairwise(
Returns:
Cosine similarity for each pair of embeddings
"""
-        norm1 = np.linalg.norm(
-            embeddings1.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        norm2 = np.linalg.norm(
-            embeddings2.reshape(-1, self.embedding_dim), axis=1, keepdims=True
-        )
-        normalized1 = embeddings1 / (norm1 + 1e-10)
-        normalized2 = embeddings2 / (norm2 + 1e-10)
-        normalized1 = np.asarray(normalized1)
-        normalized2 = np.asarray(normalized2)
-        return np.sum(normalized1 * normalized2, axis=1)
+        return select_pairwise_similarity(
+            embeddings1, embeddings2, self.mteb_model_meta.similarity_fn_name
+        )


random_encoder_baseline = ModelMeta(
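The refactor above delegates scoring to helpers in mteb.similarity_functions. As a reading aid, here is a minimal sketch of what the cosine branch of select_similarity computes, equivalent to the removed inline code (the dispatch is inferred from the call sites, not taken from the library's actual implementation):

import numpy as np


def cosine_similarity(embeddings1: np.ndarray, embeddings2: np.ndarray) -> np.ndarray:
    # Same computation the removed code performed inline:
    # L2-normalize both sides, then take the dot product.
    norm1 = np.linalg.norm(embeddings1, axis=1, keepdims=True)
    norm2 = np.linalg.norm(embeddings2, axis=1, keepdims=True)
    return np.dot(embeddings1 / (norm1 + 1e-10), (embeddings2 / (norm2 + 1e-10)).T)


def select_similarity_sketch(embeddings1, embeddings2, similarity_fn_name):
    # Hypothetical dispatch mirroring select_similarity's call signature.
    if similarity_fn_name == "cosine":
        return cosine_similarity(embeddings1, embeddings2)
    if similarity_fn_name == "dot":
        return np.dot(embeddings1, embeddings2.T)
    raise ValueError(f"Unsupported similarity function: {similarity_fn_name}")

Judging by the replaced code, select_pairwise_similarity differs only in scoring aligned pairs (row i against row i) rather than the full matrix.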
8 changes: 8 additions & 0 deletions mteb/models/search_encoder_index/__init__.py
@@ -0,0 +1,8 @@
from .search_backend_protocol import IndexEncoderSearchProtocol
from .search_indexes import FaissSearchIndex, StreamingSearchIndex

__all__ = [
"FaissSearchIndex",
"IndexEncoderSearchProtocol",
"StreamingSearchIndex",
]
50 changes: 50 additions & 0 deletions mteb/models/search_encoder_index/search_backend_protocol.py
@@ -0,0 +1,50 @@
from collections.abc import Callable
from typing import Protocol

from mteb.types import Array, TopRankedDocumentsType


class IndexEncoderSearchProtocol(Protocol):
Review comment (Contributor), suggested change:

-class IndexEncoderSearchProtocol(Protocol):
+class EncoderSearchProtocol(Protocol):

Reply (Member, author):
I think we should specify that this is only for the index and only for the encoder, because otherwise it could be confused with something that SentenceTransformerEncoderWrapper will (probably) implement.

"""Protocol for search backends used in encoder-based retrieval."""

def add_document(
Review comment (Contributor), suggested change:

-    def add_document(
+    def add_documents(
self,
embeddings: Array,
idxs: list[str],
) -> None:
"""Add documents to the search backend.

Args:
embeddings: Embeddings of the documents to add.
idxs: IDs of the documents to add.
"""

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search through added corpus embeddings or rerank top-ranked documents.

Supports both full-corpus and reranking search modes:
- Full-corpus mode: `top_ranked=None`, uses added corpus embeddings.
- Reranking mode: `top_ranked` contains mapping {query_id: [doc_ids]}.

Args:
embeddings: Query embeddings, shape (num_queries, dim).
top_k: Number of top results to return.
similarity_fn: Function to compute similarity between query and corpus.
top_ranked: Mapping of query_id -> list of candidate doc_ids. Used for reranking.
query_idx_to_id: Mapping of query index -> query_id. Used for reranking.

Returns:
A tuple (top_k_values, top_k_indices) where, for each query:
- top_k_values: List of top-k similarity scores.
- top_k_indices: List of indices of the top-k documents in the added corpus.
"""

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""
7 changes: 7 additions & 0 deletions mteb/models/search_encoder_index/search_indexes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .faiss_search_index import FaissSearchIndex
from .streaming_search_index import StreamingSearchIndex

__all__ = [
"FaissSearchIndex",
"StreamingSearchIndex",
]
157 changes: 157 additions & 0 deletions mteb/models/search_encoder_index/search_indexes/faiss_search_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import logging
from collections.abc import Callable

import numpy as np
import torch

from mteb._requires_package import requires_package
from mteb.models.model_meta import ScoringFunction
from mteb.models.models_protocols import EncoderProtocol
from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class FaissSearchIndex:
"""FAISS-based backend for encoder-based search.

Supports both full-corpus retrieval and reranking (via `top_ranked`).

Notes:
- Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
- Embeddings are L2-normalized internally when the model's similarity function is cosine.
"""

_normalize: bool = False

def __init__(self, model: EncoderProtocol) -> None:
requires_package(
self,
"faiss",
"FAISS-based search",
install_instruction="pip install mteb[faiss-cpu]",
)

import faiss
from faiss import IndexFlatIP, IndexFlatL2

# https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
if model.mteb_model_meta.similarity_fn_name is ScoringFunction.DOT_PRODUCT:
self.index_type = IndexFlatIP
elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.COSINE:
self.index_type = IndexFlatIP
self._normalize = True
elif model.mteb_model_meta.similarity_fn_name is ScoringFunction.EUCLIDEAN:
self.index_type = IndexFlatL2
else:
raise ValueError(
f"FAISS backend does not support similarity function {model.mteb_model_meta.similarity_fn_name}. "
f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}."
)

self.idxs: list[str] = []
self.index: faiss.Index | None = None

def add_document(self, embeddings: Array, idxs: list[str]) -> None:
"""Add all document embeddings and their IDs to FAISS index."""
import faiss

if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.detach().cpu().numpy()

embeddings = embeddings.astype(np.float32)
self.idxs.extend(idxs)

if self._normalize:
faiss.normalize_L2(embeddings)

dim = embeddings.shape[1]
if self.index is None:
self.index = self.index_type(dim)

self.index.add(embeddings)
logger.info(f"FAISS index built with {len(idxs)} vectors of dim {dim}.")

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search using FAISS."""
import faiss

if self.index is None:
raise ValueError("No index built. Call add_document() first.")

if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.detach().cpu().numpy()

if self._normalize:
faiss.normalize_L2(embeddings)

if top_ranked is not None:
if query_idx_to_id is None:
raise ValueError("query_idx_to_id must be provided when reranking.")

similarities, ids = self._reranking(
embeddings,
top_k,
top_ranked=top_ranked,
query_idx_to_id=query_idx_to_id,
)
else:
similarities, ids = self.index.search(embeddings.astype(np.float32), top_k)
similarities = similarities.tolist()
ids = ids.tolist()

if issubclass(self.index_type, faiss.IndexFlatL2):
    # IndexFlatL2 returns squared L2 distances; negate their square roots
    # so that larger scores are better. Convert row by row, since the
    # reranking path can produce ragged per-query result lists.
    similarities = [
        (-np.sqrt(np.maximum(row, 0))).tolist() for row in similarities
    ]

return similarities, ids

def _reranking(
self,
embeddings: Array,
top_k: int,
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
scores_all: list[list[float]] = []
idxs_all: list[list[int]] = []

for query_idx, query_emb in enumerate(embeddings):
query_id = query_idx_to_id[query_idx]
ranked_ids = top_ranked.get(query_id)
if not ranked_ids:
logger.warning(f"No top-ranked documents for query {query_id}")
scores_all.append([])
idxs_all.append([])
continue

candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
d = self.index.d
candidate_embs = np.vstack(
[self.index.reconstruct(idx) for idx in candidate_indices]
)
sub_reranking_index = self.index_type(d)
sub_reranking_index.add(candidate_embs)

# Search returns scores and indices in one call
scores, local_indices = sub_reranking_index.search(
query_emb.reshape(1, -1).astype(np.float32),
min(top_k, len(candidate_indices)),
)
# FAISS returns 2D arrays even for a single query
scores_all.append(scores[0].tolist())
idxs_all.append(local_indices[0].tolist())

return scores_all, idxs_all

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""
self.index = None
self.idxs = []
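End to end, the FAISS backend can be exercised as in the following sketch. The SimpleNamespace stands in for a real EncoderProtocol model, since the constructor only reads model.mteb_model_meta.similarity_fn_name, and similarity_fn is ignored by the FAISS search path:

import numpy as np
from types import SimpleNamespace

from mteb.models.model_meta import ScoringFunction
from mteb.models.search_encoder_index import FaissSearchIndex

# Hypothetical stand-in for a real encoder's metadata.
model = SimpleNamespace(
    mteb_model_meta=SimpleNamespace(similarity_fn_name=ScoringFunction.COSINE)
)

index = FaissSearchIndex(model)
index.add_document(
    np.random.rand(100, 32).astype(np.float32),
    [f"doc{i}" for i in range(100)],
)

scores, ids = index.search(
    np.random.rand(4, 32).astype(np.float32),
    top_k=5,
    similarity_fn=lambda q, d: q @ d.T,  # unused by the FAISS index itself
)
index.clear()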
99 changes: 99 additions & 0 deletions mteb/models/search_encoder_index/search_indexes/streaming_search_index.py
@@ -0,0 +1,99 @@
import logging
from collections.abc import Callable

import torch

from mteb.types import Array, TopRankedDocumentsType

logger = logging.getLogger(__name__)


class StreamingSearchIndex:
"""Streaming backend for encoder-based search.

- Does not store the entire corpus in memory.
- Encodes and searches corpus in chunks.
"""

sub_corpus_embeddings: Array | None = None
idxs: list[str]

def add_document(
self,
embeddings: Array,
idxs: list[str],
) -> None:
"""Add all document embeddings and their IDs to the backend."""
self.sub_corpus_embeddings = embeddings
self.idxs = idxs

def search(
self,
embeddings: Array,
top_k: int,
similarity_fn: Callable[[Array, Array], Array],
top_ranked: TopRankedDocumentsType | None = None,
query_idx_to_id: dict[int, str] | None = None,
) -> tuple[list[list[float]], list[list[int]]]:
"""Search through added corpus embeddings or rerank top-ranked documents."""
if self.sub_corpus_embeddings is None:
raise ValueError("No corpus embeddings found. Did you call add_document()?")

if top_ranked is not None:
if query_idx_to_id is None:
raise ValueError("query_idx_to_id is required when using top_ranked.")

scores_all: list[list[float]] = []
idxs_all: list[list[int]] = []

doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}

for query_idx, query_emb in enumerate(embeddings):
query_id = query_idx_to_id[query_idx]
ranked_ids = top_ranked.get(query_id)
if not ranked_ids:
logger.warning(f"No top-ranked docs for query {query_id}")
scores_all.append([])
idxs_all.append([])
continue

candidate_idx = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
candidate_embs = self.sub_corpus_embeddings[candidate_idx]

scores = similarity_fn(
torch.as_tensor(query_emb).unsqueeze(0),
torch.as_tensor(candidate_embs),
)

values, indices = torch.topk(
torch.as_tensor(scores),
k=min(top_k, len(candidate_idx)),
dim=1,
largest=True,
)
scores_all.append(values.squeeze(0).cpu().tolist())
idxs_all.append(indices.squeeze(0).cpu().tolist())

return scores_all, idxs_all

scores = similarity_fn(embeddings, self.sub_corpus_embeddings)
self.sub_corpus_embeddings = None

cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
    torch.as_tensor(scores),
    # Each row holds one score per corpus document; cap k at the corpus size.
    min(top_k + 1, len(scores[0])),
    dim=1,
    largest=True,
)
return (
cos_scores_top_k_values.cpu().tolist(),
cos_scores_top_k_idx.cpu().tolist(),
)

def clear(self) -> None:
"""Clear all stored documents and embeddings from the backend."""
self.sub_corpus_embeddings = None
self.idxs = []
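Finally, a short sketch of the reranking mode with the streaming backend. top_ranked maps each query ID to candidate document IDs, and the returned indices are positions within each query's candidate list (toy data; cosine here is just one valid choice of similarity_fn):

import torch

from mteb.models.search_encoder_index import StreamingSearchIndex


def cosine_sim(queries: torch.Tensor, docs: torch.Tensor) -> torch.Tensor:
    # L2-normalize both sides, then score every query against every doc.
    queries = torch.nn.functional.normalize(torch.as_tensor(queries), dim=-1)
    docs = torch.nn.functional.normalize(torch.as_tensor(docs), dim=-1)
    return queries @ docs.T


index = StreamingSearchIndex()
index.add_document(torch.rand(6, 16), ["d0", "d1", "d2", "d3", "d4", "d5"])

scores, local_idxs = index.search(
    torch.rand(2, 16),
    top_k=2,
    similarity_fn=cosine_sim,
    top_ranked={"q0": ["d1", "d4"], "q1": ["d0", "d2", "d5"]},
    query_idx_to_id={0: "q0", 1: "q1"},
)
# local_idxs[0] indexes into q0's candidate list, so 0 means "d1".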