[chore] Add unittests for InformationRetrievalEvaluator #2838

Merged · 6 commits · Aug 29, 2024
tests/evaluation/test_binary_classification_evaluator.py (44 additions, 0 deletions)
@@ -0,0 +1,44 @@
"""
Tests the correct computation of evaluation scores from BinaryClassificationEvaluator
"""

from __future__ import annotations

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

from sentence_transformers import (
    evaluation,
)


def test_BinaryClassificationEvaluator_find_best_f1_and_threshold() -> None:
"""Tests that the F1 score for the computed threshold is correct"""
y_true = np.random.randint(0, 2, 1000)
y_pred_cosine = np.random.randn(1000)
(
best_f1,
best_precision,
best_recall,
threshold,
) = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(
y_pred_cosine, y_true, high_score_more_similar=True
)
y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
sklearn_f1score = f1_score(y_true, y_pred_labels)
assert np.abs(best_f1 - sklearn_f1score) < 1e-6


def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold() -> None:
"""Tests that the Acc score for the computed threshold is correct"""
y_true = np.random.randint(0, 2, 1000)
y_pred_cosine = np.random.randn(1000)
(
max_acc,
threshold,
) = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(
y_pred_cosine, y_true, high_score_more_similar=True
)
y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
sklearn_acc = accuracy_score(y_true, y_pred_labels)
assert np.abs(max_acc - sklearn_acc) < 1e-6
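Both tests above compare the evaluator's threshold search against sklearn scores computed on the same predictions. For readers unfamiliar with those helpers, the sketch below shows one way such a search could be done by brute force; it is illustrative only, not the library's implementation, and the function name brute_force_best_f1 is invented for this example.

import numpy as np
from sklearn.metrics import f1_score


def brute_force_best_f1(scores: np.ndarray, labels: np.ndarray) -> tuple[float, float]:
    # Try every observed score as a decision threshold and keep the one that maximizes F1.
    best_f1, best_threshold = 0.0, float(scores.max())
    for threshold in np.unique(scores):
        predictions = (scores >= threshold).astype(int)
        f1 = f1_score(labels, predictions)
        if f1 > best_f1:
            best_f1, best_threshold = f1, float(threshold)
    return best_f1, best_threshold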
tests/evaluation/test_information_retrieval_evaluator.py (141 additions, 0 deletions)
@@ -0,0 +1,141 @@
from __future__ import annotations

from unittest.mock import Mock, PropertyMock

import pytest
import torch

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator


@pytest.fixture
def mock_model():
    def mock_encode(sentences: str | list[str], **kwargs) -> torch.Tensor:
        """
        We simply one-hot encode the sentences; if a sentence contains a keyword, the corresponding one-hot
        encoding is added to the sentence embedding.
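        For example, "A car is a vehicle" contains both "car" and "vehicle", so it is encoded as [0., 1., 1., 0., 0.].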
"""
one_hot_encodings = {
"pokemon": torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]),
"car": torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0]),
"vehicle": torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]),
"fruit": torch.tensor([0.0, 0.0, 0.0, 1.0, 0.0]),
"vegetable": torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0]),
}
if isinstance(sentences, str):
sentences = [sentences]
embeddings = []
for sentence in sentences:
encoding = torch.zeros(5)
for keyword, one_hot in one_hot_encodings.items():
if keyword in sentence:
encoding += one_hot
embeddings.append(encoding)
return torch.stack(embeddings)

model = Mock(spec=SentenceTransformer)
model.encode.side_effect = mock_encode
model.model_card_data = PropertyMock(return_value=Mock())
return model


@pytest.fixture
def test_data():
    queries = {
        "0": "What is a pokemon?",
        "1": "What is a vegetable?",
        "2": "What is a fruit?",
        "3": "What is a vehicle?",
        "4": "What is a car?",
    }
    corpus = {
        "0": "A pokemon is a fictional creature",
        "1": "A vegetable is a plant",
        "2": "A fruit is a plant",
        "3": "A vehicle is a machine",
        "4": "A car is a vehicle",
    }
    relevant_docs = {"0": {"0"}, "1": {"1"}, "2": {"2"}, "3": {"3", "4"}, "4": {"4"}}
    return queries, corpus, relevant_docs


def test_simple(test_data):
    queries, corpus, relevant_docs = test_data
    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="test",
        accuracy_at_k=[1, 3],
        precision_recall_at_k=[1, 3],
        mrr_at_k=[3],
        ndcg_at_k=[3],
        map_at_k=[5],
    )
    results = ir_evaluator(model)
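    # The evaluator reports one value per similarity function and metric, keyed as {name}_{similarity}_{metric}@{k}: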
    expected_keys = [
        "test_cosine_accuracy@1",
        "test_cosine_accuracy@3",
        "test_cosine_precision@1",
        "test_cosine_precision@3",
        "test_cosine_recall@1",
        "test_cosine_recall@3",
        "test_cosine_ndcg@3",
        "test_cosine_mrr@3",
        "test_cosine_map@5",
        "test_dot_accuracy@1",
        "test_dot_accuracy@3",
        "test_dot_precision@1",
        "test_dot_precision@3",
        "test_dot_recall@1",
        "test_dot_recall@3",
        "test_dot_ndcg@3",
        "test_dot_mrr@3",
        "test_dot_map@5",
    ]
    assert set(results.keys()) == set(expected_keys)


def test_metrics(test_data, mock_model):
    queries, corpus, relevant_docs = test_data

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="test",
        accuracy_at_k=[1, 3],
        precision_recall_at_k=[1, 3],
        mrr_at_k=[3],
        ndcg_at_k=[3],
        map_at_k=[5],
    )
    results = ir_evaluator(mock_model)
    # We expect test_cosine_precision@3 to be 0.4, since 6 of the 15 retrieved documents (5 queries * 3) are true positives
    # We expect test_cosine_recall@1 to be 0.9: four queries have a recall of 1 and one has a recall of 0.5
    # (a worked check of both values follows this test)
    expected_results = {
        "test_cosine_accuracy@1": 1.0,
        "test_cosine_accuracy@3": 1.0,
        "test_cosine_precision@1": 1.0,
        "test_cosine_precision@3": 0.4,
        "test_cosine_recall@1": 0.9,
        "test_cosine_recall@3": 1.0,
        "test_cosine_ndcg@3": 1.0,
        "test_cosine_mrr@3": 1.0,
        "test_cosine_map@5": 1.0,
        "test_dot_accuracy@1": 1.0,
        "test_dot_accuracy@3": 1.0,
        "test_dot_precision@1": 1.0,
        "test_dot_precision@3": 0.4,
        "test_dot_recall@1": 0.9,
        "test_dot_recall@3": 1.0,
        "test_dot_ndcg@3": 1.0,
        "test_dot_mrr@3": 1.0,
        "test_dot_map@5": 1.0,
    }

    for key, expected_value in expected_results.items():
        assert results[key] == pytest.approx(expected_value, abs=1e-9)
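For reference, here is a worked check of the two non-trivial expectations above, using plain arithmetic rather than the evaluator's own code: query "3" has two relevant documents ("3" and "4"), so retrieving a single document at rank 1 caps its recall@1 at 0.5, while every other query has exactly one relevant document that the mock encoder ranks first.

# Hits among the top 3 retrieved documents per query: queries "0", "1", "2" and "4" each
# contribute one relevant document, query "3" contributes two.
hits_at_3 = [1, 1, 1, 2, 1]
precision_at_3 = sum(hits_at_3) / (5 * 3)  # 6 / 15 = 0.4

# Recall at rank 1 per query: only query "3" misses one of its two relevant documents.
recall_at_1_per_query = [1.0, 1.0, 1.0, 0.5, 1.0]
recall_at_1 = sum(recall_at_1_per_query) / 5  # 4.5 / 5 = 0.9

assert precision_at_3 == 0.4 and recall_at_1 == 0.9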
tests/evaluation/test_label_accuracy_evaluator.py (50 additions, 0 deletions)
@@ -0,0 +1,50 @@
"""
Tests the correct computation of evaluation scores from LabelAccuracyEvaluator
"""

from __future__ import annotations

import csv
import gzip
import os

from torch.utils.data import DataLoader

from sentence_transformers import (
    InputExample,
    SentenceTransformer,
    evaluation,
    losses,
    util,
)


def test_LabelAccuracyEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None:
"""Tests that the LabelAccuracyEvaluator can be loaded correctly"""
model = paraphrase_distilroberta_base_v1_model
nli_dataset_path = "datasets/AllNLI.tsv.gz"
if not os.path.exists(nli_dataset_path):
util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path)

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
dev_samples = []
with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn:
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
for row in reader:
if row["split"] == "train":
label_id = label2int[row["label"]]
dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=label_id))
if len(dev_samples) >= 100:
break

train_loss = losses.SoftmaxLoss(
model=model,
sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
num_labels=len(label2int),
)

dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
metrics = evaluator(model)
assert "accuracy" in metrics
assert metrics["accuracy"] > 0.2
tests/evaluation/test_paraphrase_mining_evaluator.py (24 additions, 0 deletions)
@@ -0,0 +1,24 @@
"""
Tests the correct computation of evaluation scores from ParaphraseMiningEvaluator
"""

from __future__ import annotations

from sentence_transformers import (
    SentenceTransformer,
    evaluation,
)


def test_ParaphraseMiningEvaluator(paraphrase_distilroberta_base_v1_model: SentenceTransformer) -> None:
"""Tests that the ParaphraseMiningEvaluator can be loaded"""
model = paraphrase_distilroberta_base_v1_model
sentences = {
0: "Hello World",
1: "Hello World!",
2: "The cat is on the table",
3: "On the table the cat is",
}
data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
metrics = data_eval(model)
assert metrics[data_eval.primary_metric] > 0.99
tests/test_evaluator.py (0 additions, 98 deletions)

This file was deleted.