Skip to content

Commit f4f98bc

Browse files
authored
COH-32073 - Add support for HnswIndex in ai module (#232)
* COH-32073 - Add support for HnswIndex in ai module
1 parent fc20f96 commit f4f98bc

File tree

3 files changed

+149
-28
lines changed

3 files changed

+149
-28
lines changed

src/coherence/ai.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import base64
88
from abc import ABC
99
from collections import OrderedDict
10-
from typing import Any, Dict, List, Optional, TypeVar, Union, cast
10+
from typing import Any, Dict, Final, List, Optional, TypeVar, Union, cast
1111

1212
import jsonpickle
1313
import numpy as np
@@ -337,6 +337,93 @@ def __init__(self, extractor: Union[ValueExtractor[T, E], str], over_sampling_fa
337337
self.oversamplingFactor = over_sampling_factor
338338

339339

340+
@proxy("coherence.hnsw.HnswIndex")
341+
class HnswIndex(AbstractEvolvable):
342+
DEFAULT_SPACE_NAME: Final[str] = "COSINE"
343+
"""The default index space name."""
344+
345+
DEFAULT_MAX_ELEMENTS: Final[int] = 4096
346+
"""
347+
The default maximum number of elements the index can contain is 4096
348+
but the index will grow automatically by doubling its capacity until it
349+
reaches approximately 8m elements, at which point it will grow by 50%
350+
whenever it gets full.
351+
"""
352+
353+
DEFAULT_M: Final[int] = 16
354+
"""
355+
The default number of bidirectional links created for every new
356+
element during construction is 2-100. Higher M work better on datasets
357+
with high intrinsic dimensionality and/or high recall, while low M work
358+
better for datasets with low intrinsic dimensionality and/or low recalls.
359+
The parameter also determines the algorithm's memory consumption,
360+
which is roughly M * 8-10 bytes per stored element. As an example for
361+
dim=4 random vectors optimal M for search is somewhere around 6,
362+
while for high dimensional datasets (word embeddings, good face
363+
descriptors), higher M are required (e.g. M=48-64) for optimal
364+
performance at high recall. The range M=12-48 is ok for the most of the
365+
use cases. When M is changed one has to update the other parameters.
366+
Nonetheless, ef and ef_construction parameters can be roughly estimated
367+
by assuming that M*ef_{construction} is a constant. The default value is
368+
16.
369+
"""
370+
371+
DEFAULT_EF_CONSTRUCTION: Final[int] = 200
372+
"""
373+
The parameter has the same meaning as ef, which controls the
374+
index_time/index_accuracy. Bigger ef_construction leads to longer
375+
construction, but better index quality. At some point, increasing
376+
ef_construction does not improve the quality of the index. One way to
377+
check if the selection of ef_construction was ok is to measure a recall
378+
for M nearest neighbor search when ef =ef_construction: if the recall is
379+
lower than 0.9, than there is room for improvement. The default value is
380+
200.
381+
"""
382+
383+
DEFAULT_EF_SEARCH: Final[int] = 50
384+
"""
385+
The parameter controlling query time/accuracy trade-off. The default
386+
value is 50.
387+
"""
388+
389+
DEFAULT_RANDOM_SEED: Final[int] = 100
390+
"""The default random seed used for the index."""
391+
392+
def __init__(
393+
self,
394+
extractor: Union[ValueExtractor[T, E], str],
395+
dimensions: int,
396+
space_name: str = DEFAULT_SPACE_NAME,
397+
max_elements: int = DEFAULT_MAX_ELEMENTS,
398+
m: int = DEFAULT_M,
399+
ef_construction: int = DEFAULT_EF_CONSTRUCTION,
400+
ef_search: int = DEFAULT_EF_SEARCH,
401+
random_seed: int = DEFAULT_RANDOM_SEED,
402+
) -> None:
403+
"""
404+
Creates an instance of HnswIndex class.
405+
406+
:param extractor: The ValueExtractor to use to extract the Vector.
407+
:param dimensions: The number of dimensions in the vector.
408+
:param space_name: The index space name.
409+
:param max_elements: The maximum number of elements the index can contain.
410+
:param m: The number of bidirectional links created for every new element during construction.
411+
:param ef_construction: The parameter controlling the index_time/index_accuracy.
412+
:param ef_search: The parameter controlling query time/accuracy trade-off.
413+
:param random_seed: The random seed used for the index.
414+
"""
415+
416+
super().__init__()
417+
self.extractor = extractor
418+
self.dimensions = dimensions
419+
self.spaceName = space_name if space_name else ""
420+
self.maxElements = max_elements
421+
self.m = m
422+
self.efConstruction = ef_construction
423+
self.efSearch = ef_search
424+
self.randomSeed = random_seed
425+
426+
340427
class Vectors:
341428

342429
EPSILON = 1e-30 # Python automatically handles float precision

tests/e2e/test_ai.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
from coherence import COH_LOG, Extractors, NamedCache, Session
11-
from coherence.ai import BinaryQuantIndex, DocumentChunk, FloatVector, SimilaritySearch, Vectors
11+
from coherence.ai import BinaryQuantIndex, DocumentChunk, FloatVector, HnswIndex, SimilaritySearch, Vectors
1212

1313

1414
class ValueWithVector:
@@ -94,9 +94,49 @@ async def populate_document_chunk_vectors(vectors: NamedCache[int, DocumentChunk
9494

9595
@pytest.mark.asyncio
9696
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
97-
async def test_similarity_search_with_index(test_session: Session) -> None:
97+
async def test_similarity_search_with_binary_quant_index(test_session: Session) -> None:
98+
await _run_similarity_search_with_index(test_session, "BinaryQuantIndex")
99+
100+
101+
@pytest.mark.asyncio
102+
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
103+
async def test_similarity_search_with_document_chunk(test_session: Session) -> None:
104+
cache: NamedCache[int, DocumentChunk] = await test_session.get_cache("vector_cache")
105+
dc: DocumentChunk = await populate_document_chunk_vectors(cache)
106+
107+
# Create a SimilaritySearch aggregator
108+
value_extractor = Extractors.extract("vector")
109+
k = 10
110+
ss = SimilaritySearch(value_extractor, dc.vector, k)
111+
112+
hnsw_result = await cache.aggregate(ss)
113+
114+
assert hnsw_result is not None
115+
assert len(hnsw_result) == k
116+
COH_LOG.info("Results below for test_SimilaritySearch_with_DocumentChunk:")
117+
for e in hnsw_result:
118+
COH_LOG.info(e)
119+
120+
await cache.truncate()
121+
await cache.destroy()
122+
123+
124+
@pytest.mark.asyncio
125+
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
126+
async def test_similarity_search_with_hnsw_index(test_session: Session) -> None:
127+
await _run_similarity_search_with_index(test_session, "HnswIndex")
128+
129+
130+
async def _run_similarity_search_with_index(test_session: Session, index_type: str) -> None:
98131
cache: NamedCache[int, ValueWithVector] = await test_session.get_cache("vector_cache")
99-
cache.add_index(BinaryQuantIndex(Extractors.extract("vector")))
132+
if index_type == "BinaryQuantIndex":
133+
cache.add_index(BinaryQuantIndex(Extractors.extract("vector")))
134+
elif index_type == "HnswIndex":
135+
cache.add_index(HnswIndex(Extractors.extract("vector"), DIMENSIONS))
136+
else:
137+
COH_LOG.error("NO index_type specified")
138+
return
139+
100140
value_with_vector = await populate_vectors(cache)
101141

102142
# Create a SimilaritySearch aggregator
@@ -122,7 +162,7 @@ async def test_similarity_search_with_index(test_session: Session) -> None:
122162
hnsw_result = await cache.aggregate(ss)
123163
end_time = time.perf_counter()
124164
elapsed_time = end_time - start_time
125-
COH_LOG.info("Results below for test_SimilaritySearch with Index:")
165+
COH_LOG.info("Results below for test_SimilaritySearch with HnswIndex:")
126166
for e in hnsw_result:
127167
COH_LOG.info(e)
128168
COH_LOG.info(f"Elapsed time: {elapsed_time} seconds")
@@ -132,26 +172,3 @@ async def test_similarity_search_with_index(test_session: Session) -> None:
132172

133173
await cache.truncate()
134174
await cache.destroy()
135-
136-
137-
@pytest.mark.asyncio
138-
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
139-
async def test_similarity_search_with_document_chunk(test_session: Session) -> None:
140-
cache: NamedCache[int, DocumentChunk] = await test_session.get_cache("vector_cache")
141-
dc: DocumentChunk = await populate_document_chunk_vectors(cache)
142-
143-
# Create a SimilaritySearch aggregator
144-
value_extractor = Extractors.extract("vector")
145-
k = 10
146-
ss = SimilaritySearch(value_extractor, dc.vector, k)
147-
148-
hnsw_result = await cache.aggregate(ss)
149-
150-
assert hnsw_result is not None
151-
assert len(hnsw_result) == k
152-
COH_LOG.info("Results below for test_SimilaritySearch_with_DocumentChunk:")
153-
for e in hnsw_result:
154-
COH_LOG.info(e)
155-
156-
await cache.truncate()
157-
await cache.destroy()

tests/unit/test_serialization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CosineDistance,
1616
DocumentChunk,
1717
FloatVector,
18+
HnswIndex,
1819
QueryResult,
1920
SimilaritySearch,
2021
)
@@ -253,3 +254,19 @@ def test_binary_quant_index_serialization() -> None:
253254

254255
o = s.deserialize(ser)
255256
assert isinstance(o, BinaryQuantIndex)
257+
258+
259+
# noinspection PyUnresolvedReferences
260+
def test_HnswIndex_serialization() -> None:
261+
bqi = HnswIndex(Extractors.extract("foo"), 384)
262+
ser = s.serialize(bqi)
263+
assert ser == (
264+
b'\x15{"@class": "coherence.hnsw.HnswIndex", "dataVersion": 0, '
265+
b'"binFuture": null, "extractor": {"@class": "extractor.UniversalExtractor", '
266+
b'"name": "foo", "params": null}, "dimensions": 384, "spaceName": "COSINE", '
267+
b'"maxElements": 4096, "m": 16, "efConstruction": 200, "efSearch": 50, '
268+
b'"randomSeed": 100}'
269+
)
270+
271+
o = s.deserialize(ser)
272+
assert isinstance(o, HnswIndex)

0 commit comments

Comments
 (0)