66 changes: 65 additions & 1 deletion tests/test_search.py
@@ -1,18 +1,23 @@
# Standard library imports
import os
import tempfile
from functools import partial
from unittest import TestCase
from unittest.mock import patch

# Third-party imports
import numpy as np
import pytest

# Hugging Face Datasets imports
from datasets.arrow_dataset import Dataset
from datasets.search import ElasticSearchIndex, FaissIndex, MissingIndex

# Local test utilities that skip tests when optional dependencies are missing
from .utils import require_elasticsearch, require_faiss


# Mark all tests in this file as integration tests
pytestmark = pytest.mark.integration
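
Side note, not part of the diff: with this module-level marker, the whole file can be selected or deselected from the command line, e.g. pytest -m integration tests/test_search.py or pytest -m "not integration", assuming the "integration" marker is registered in the project's pytest configuration.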


@@ -26,35 +31,58 @@ def test_add_faiss_index(self):
import faiss

dset: Dataset = self._create_dummy_dataset()
# Add a vector column where row i holds the vector i * ones(5)
dset = dset.map(
lambda ex, i: {"vecs": i * np.ones(5, dtype=np.float32)}, with_indices=True, keep_in_memory=True
)

# Create a FAISS index on the vector column
dset = dset.add_faiss_index("vecs", batch_size=100, metric_type=faiss.METRIC_INNER_PRODUCT)

# Query the index with a vector of ones
scores, examples = dset.get_nearest_examples("vecs", np.ones(5, dtype=np.float32))

# Row i holds i * ones(5), so the last row has the largest inner product with the all-ones query
self.assertEqual(examples["filename"][0], "my_name-train_29")

# Clean up by removing the index
dset.drop_index("vecs")
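
For reference (not part of the diff), a minimal runnable sketch of the arithmetic the assertion above relies on: row i stores i * ones(5), so its inner product with the all-ones query is 5 * i, which is largest for the last row.

import numpy as np

vecs = np.arange(30, dtype=np.float32).reshape(-1, 1) * np.ones((30, 5), dtype=np.float32)
query = np.ones(5, dtype=np.float32)
scores = vecs @ query              # inner product of row i with the query is 5 * i
assert int(scores.argmax()) == 29  # hence "my_name-train_29" ranks first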

def test_add_faiss_index_errors(self):
import faiss

dset: Dataset = self._create_dummy_dataset()

# String columns cannot be indexed with FAISS
with pytest.raises(ValueError, match="Wrong feature type for column 'filename'"):
_ = dset.add_faiss_index("filename", batch_size=100, metric_type=faiss.METRIC_INNER_PRODUCT)

def test_add_faiss_index_from_external_arrays(self):
"""
Test adding a FAISS index using externally provided
numpy arrays instead of dataset columns.
"""
import faiss

dset: Dataset = self._create_dummy_dataset()

# External vectors are constructed independently of the dataset
dset.add_faiss_index_from_external_arrays(
external_arrays=np.ones((30, 5)) * np.arange(30).reshape(-1, 1),
index_name="vecs",
batch_size=100,
metric_type=faiss.METRIC_INNER_PRODUCT,
)

# Perform a nearest-neighbor query
scores, examples = dset.get_nearest_examples("vecs", np.ones(5, dtype=np.float32))
self.assertEqual(examples["filename"][0], "my_name-train_29")
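
Design note, not part of the diff: add_faiss_index_from_external_arrays covers the case where the vectors are computed outside the dataset (for example by a separate encoder), so they can be indexed and queried without first being added as a column.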

def test_serialization(self):
"""
Verify that FAISS indexes can be serialized to disk
and reloaded correctly.
"""
import faiss

dset: Dataset = self._create_dummy_dataset()
@@ -73,21 +101,34 @@ def test_serialization(self):
dset.load_faiss_index("vecs2", tmp_file.name)
os.unlink(tmp_file.name)

# Queries against the reloaded index should return the same top result
scores, examples = dset.get_nearest_examples("vecs2", np.ones(5, dtype=np.float32))
self.assertEqual(examples["filename"][0], "my_name-train_29")
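
The hunk above is collapsed, so only the load side is visible. For reference, a rough sketch of the full round trip with the Dataset helpers save_faiss_index and load_faiss_index, assuming dset already has a FAISS index named "vecs" as in test_add_faiss_index above (this is a sketch, not the collapsed code):

import tempfile

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    dset.save_faiss_index("vecs", tmp.name)    # persist the index to disk
dset.load_faiss_index("vecs2", tmp.name)       # reload it under a new name, as the test does
os.unlink(tmp.name)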

def test_drop_index(self):
"""
Ensure that dropping an index removes it fully
and raises an error on subsequent access.
"""
dset: Dataset = self._create_dummy_dataset()
dset.add_faiss_index_from_external_arrays(
external_arrays=np.ones((30, 5)) * np.arange(30).reshape(-1, 1), index_name="vecs"
)
dset.drop_index("vecs")

# Accessing an index name that does not exist should raise MissingIndex
self.assertRaises(MissingIndex, partial(dset.get_nearest_examples, "vecs2", np.ones(5, dtype=np.float32)))
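
Not part of the diff: the same expectation reads a little more directly with pytest's context-manager idiom (pytest is already imported at the top of the module); querying the dropped "vecs" index also raises MissingIndex.

with pytest.raises(MissingIndex):
    dset.get_nearest_examples("vecs", np.ones(5, dtype=np.float32))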

def test_add_elasticsearch_index(self):
"""
Validate Elasticsearch-based indexing and querying
using mocked Elasticsearch APIs.
"""
from elasticsearch import Elasticsearch

dset: Dataset = self._create_dummy_dataset()

# Mock Elasticsearch calls to avoid real network usage
with (
patch("elasticsearch.Elasticsearch.search") as mocked_search,
patch("elasticsearch.client.IndicesClient.create") as mocked_index_create,
@@ -105,15 +146,26 @@ def test_add_elasticsearch_index(self):

@require_faiss
class FaissIndexTest(TestCase):
"""
Unit tests for the FaissIndex abstraction.
"""

def test_flat_ip(self):
"""
Test FAISS flat inner-product index creation,
search, and batch search behavior.
"""

import faiss

index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)

# add vectors
# Add identity vectors
index.add_vectors(np.eye(5, dtype=np.float32))
self.assertIsNotNone(index.faiss_index)
self.assertEqual(index.faiss_index.ntotal, 5)

# Add additional zero vectors
index.add_vectors(np.zeros((5, 5), dtype=np.float32))
self.assertEqual(index.faiss_index.ntotal, 10)
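
Not shown at this point in the diff: a single query against this index would look roughly like the following, treating the exact FaissIndex.search signature and return shape as an assumption rather than something taken from the collapsed hunk.

scores, indices = index.search(np.ones(5, dtype=np.float32))
# each identity vector has inner product 1 with the all-ones query and the zero vectors 0,
# so the top hits come from the first five vectors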

@@ -135,6 +187,9 @@ def test_flat_ip(self):
self.assertListEqual([4, 3, 2, 1, 0], best_indices)

def test_factory(self):
"""
Verify FAISS index creation via string factory.
"""
import faiss

index = FaissIndex(string_factory="Flat")
@@ -143,10 +198,15 @@ def test_factory(self):
index = FaissIndex(string_factory="LSH")
index.add_vectors(np.eye(5, dtype=np.float32))
self.assertIsInstance(index.faiss_index, faiss.IndexLSH)

# Passing both a string factory and a custom index is ambiguous and should raise
with self.assertRaises(ValueError):
_ = FaissIndex(string_factory="Flat", custom_index=faiss.IndexFlat(5))

def test_custom(self):
"""
Test usage of a user-provided FAISS index.
"""
import faiss

custom_index = faiss.IndexFlat(5)
@@ -155,6 +215,10 @@ def test_serialization(self):
self.assertIsInstance(index.faiss_index, faiss.IndexFlat)

def test_serialization(self):
"""
Ensure FAISS index serialization and deserialization
preserves functionality.
"""
import faiss

index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)