Skip to content

Commit

Permalink
Adding Filtered Index support to Python bindings (#482)
Browse files Browse the repository at this point in the history
* Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately

* Revert "Halfway approach to the new indexfactory, but it doesn't have the same featureset as the old way. Committing this for posterity but reverting my changes ultimately"

This reverts commit 03dccb5.

* Adding filtered search. API is going to change still.

* Further enhancements to the new filter capability in the static memory index.

* Ran automatic formatting

* Fixing my logic and ensuring the unit tests pass.

* Setting this up as an rc build first

* list[list[Hashable]] -> list[list[str]]

* Adding a partial solution for the case where we query for more items than exist in the filter set. We still need to replicate this behavior across all indices, though - dynamic, static disk, and memory without filters, etc.

* Removing the import of Hashable too
  • Loading branch information
daxpryce committed Nov 7, 2023
1 parent 179927e commit 4a57e89
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 39 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "diskannpy"
version = "0.6.1"
version = "0.7.0rc1"

description = "DiskANN Python extension module"
readme = "python/README.md"
Expand Down
5 changes: 3 additions & 2 deletions python/include/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ template <typename DT, typename TagT = DynamicIdType, typename LabelT = filterT>
void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity,
float alpha, uint32_t num_threads, bool use_pq_build,
size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity,
bool use_tags = false);
size_t num_pq_bytes, bool use_opq, bool use_tags = false,
const std::string& filter_labels_file = "", const std::string& universal_label = "",
uint32_t filter_complexity = 0);

}
4 changes: 4 additions & 0 deletions python/include/static_memory_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ template <typename DT> class StaticMemoryIndex
NeighborsAndDistances<StaticIdType> search(py::array_t<DT, py::array::c_style | py::array::forcecast> &query,
uint64_t knn, uint64_t complexity);

NeighborsAndDistances<StaticIdType> search_with_filter(
py::array_t<DT, py::array::c_style | py::array::forcecast> &query, uint64_t knn, uint64_t complexity,
filterT filter);

NeighborsAndDistances<StaticIdType> batch_search(
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, uint64_t num_queries, uint64_t knn,
uint64_t complexity, uint32_t num_threads);
Expand Down
51 changes: 45 additions & 6 deletions python/src/_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import json
import os
import shutil
from pathlib import Path
Expand Down Expand Up @@ -174,8 +175,10 @@ def build_memory_index(
num_pq_bytes: int = defaults.NUM_PQ_BYTES,
use_opq: bool = defaults.USE_OPQ,
vector_dtype: Optional[VectorDType] = None,
filter_complexity: int = defaults.FILTER_COMPLEXITY,
tags: Union[str, VectorIdentifierBatch] = "",
filter_labels: Optional[list[list[str]]] = None,
universal_label: str = "",
filter_complexity: int = defaults.FILTER_COMPLEXITY,
index_prefix: str = "ann",
) -> None:
"""
Expand Down Expand Up @@ -223,10 +226,20 @@ def build_memory_index(
Default is `0`.
- **use_opq**: Use optimized product quantization during build.
- **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array.
- **filter_complexity**: Complexity to use when using filters. Default is 0.
- **tags**: A `str` representing a path to a pre-built tags file on disk, or a `numpy.ndarray` of uint32 ids
corresponding to the ordinal position of the vectors provided to build the index. Defaults to "". **This value
must be provided if you want to build a memory index intended for use with `diskannpy.DynamicMemoryIndex`**.
- **tags**: Tags can be defined either as a path on disk to an existing .tags file, or provided as a np.array of
the same length as the number of vectors. Tags are used to identify vectors in the index via your *own*
numbering conventions, and is absolutely required for loading DynamicMemoryIndex indices `from_file`.
- **filter_labels**: An optional, but exhaustive list of categories for each vector. This is used to filter
search results by category. If provided, this must be a list of lists, where each inner list is a list of
categories for the corresponding vector. For example, if you have 3 vectors, and the first vector belongs to
categories "a" and "b", the second vector belongs to category "b", and the third vector belongs to no categories,
you would provide `filter_labels=[["a", "b"], ["b"], []]`. If you do not want to provide categories for a
particular vector, you can provide an empty list. If you do not want to provide categories for any vectors,
you can provide `None` for this parameter (which is the default)
- **universal_label**: An optional label that indicates that this vector should be included in *every* search
in which it also meets the knn search criteria.
- **filter_complexity**: Complexity to use when using filters. Default is 0. 0 is strictly invalid if you are
using filters.
- **index_prefix**: The prefix of the index files. Defaults to "ann".
"""
_assert(
Expand All @@ -245,6 +258,10 @@ def build_memory_index(
_assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes")
_assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
_assert(index_prefix != "", "index_prefix cannot be an empty string")
_assert(
filter_labels is None or filter_complexity > 0,
"if filter_labels is provided, filter_complexity must not be 0"
)

index_path = Path(index_directory)
_assert(
Expand All @@ -262,6 +279,11 @@ def build_memory_index(
)

num_points, dimensions = vectors_metadata_from_file(vector_bin_path)
if filter_labels is not None:
_assert(
len(filter_labels) == num_points,
"filter_labels must be the same length as the number of points"
)

if vector_dtype_actual == np.uint8:
_builder = _native_dap.build_memory_uint8_index
Expand All @@ -272,6 +294,21 @@ def build_memory_index(

index_prefix_path = os.path.join(index_directory, index_prefix)

filter_labels_file = ""
if filter_labels is not None:
label_counts = {}
filter_labels_file = f"{index_prefix_path}_pylabels.txt"
with open(filter_labels_file, "w") as labels_file:
for labels in filter_labels:
for label in labels:
label_counts[label] = 1 if label not in label_counts else label_counts[label] + 1
if len(labels) == 0:
print("default", file=labels_file)
else:
print(",".join(labels), file=labels_file)
with open(f"{index_prefix_path}_label_metadata.json", "w") as label_metadata_file:
json.dump(label_counts, label_metadata_file, indent=True)

if isinstance(tags, str) and tags != "":
use_tags = True
shutil.copy(tags, index_prefix_path + ".tags")
Expand Down Expand Up @@ -299,8 +336,10 @@ def build_memory_index(
use_pq_build=use_pq_build,
num_pq_bytes=num_pq_bytes,
use_opq=use_opq,
filter_complexity=filter_complexity,
use_tags=use_tags,
filter_labels_file=filter_labels_file,
universal_label=universal_label,
filter_complexity=filter_complexity,
)

_write_index_metadata(
Expand Down
12 changes: 6 additions & 6 deletions python/src/_builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def build_memory_index(
use_pq_build: bool,
num_pq_bytes: int,
use_opq: bool,
label_file: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels: Optional[list[list[str]]],
universal_label: str,
filter_complexity: int,
tags: Optional[VectorIdentifierBatch],
index_prefix: str,
index_prefix: str
) -> None: ...
@overload
def build_memory_index(
Expand All @@ -66,9 +66,9 @@ def build_memory_index(
num_pq_bytes: int,
use_opq: bool,
vector_dtype: VectorDType,
label_file: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels_file: Optional[list[list[str]]],
universal_label: str,
filter_complexity: int,
tags: Optional[str],
index_prefix: str,
index_prefix: str
) -> None: ...
22 changes: 12 additions & 10 deletions python/src/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _ensure_index_metadata(
distance_metric: Optional[DistanceMetric],
max_vectors: int,
dimensions: Optional[int],
warn_size_exceeded: bool = False,
) -> Tuple[VectorDType, str, np.uint64, np.uint64]:
possible_metadata = _read_index_metadata(index_path_and_prefix)
if possible_metadata is None:
Expand All @@ -226,16 +227,17 @@ def _ensure_index_metadata(
return vector_dtype, distance_metric, max_vectors, dimensions # type: ignore
else:
vector_dtype, distance_metric, num_vectors, dimensions = possible_metadata
if max_vectors is not None and num_vectors > max_vectors:
warnings.warn(
"The number of vectors in the saved index exceeds the max_vectors parameter. "
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
)
max_vectors = num_vectors
if num_vectors == max_vectors:
warnings.warn(
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
)
if warn_size_exceeded:
if max_vectors is not None and num_vectors > max_vectors:
warnings.warn(
"The number of vectors in the saved index exceeds the max_vectors parameter. "
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
)
max_vectors = num_vectors
if num_vectors == max_vectors:
warnings.warn(
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
)
return possible_metadata


Expand Down
2 changes: 1 addition & 1 deletion python/src/_dynamic_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def from_file(
f"The file {tags_file} does not exist in {index_directory}",
)
vector_dtype, dap_metric, num_vectors, dimensions = _ensure_index_metadata(
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions, warn_size_exceeded=True
)

index = cls(
Expand Down
45 changes: 42 additions & 3 deletions python/src/_static_memory_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import json
import os
import warnings
from typing import Optional
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
distance_metric: Optional[DistanceMetric] = None,
vector_dtype: Optional[VectorDType] = None,
dimensions: Optional[int] = None,
enable_filters: bool = False
):
"""
### Parameters
Expand Down Expand Up @@ -73,8 +75,22 @@ def __init__(
- **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
does not exist, you are required to provide it.
- **enable_filters**: Indexes built with filters can also be used for filtered search.
"""
index_prefix = _valid_index_prefix(index_directory, index_prefix)
self._labels_map = {}
self._labels_metadata = {}
if enable_filters:
try:
with open(index_prefix + "_labels_map.txt", "r") as labels_map_if:
for line in labels_map_if:
(key, val) = line.split("\t")
self._labels_map[key] = int(val)
with open(f"{index_prefix}_label_metadata.json", "r") as labels_metadata_if:
self._labels_metadata = json.load(labels_metadata_if)
except: # noqa: E722
# exceptions are basically presumed to be either file not found or file not formatted correctly
raise RuntimeException("Filter labels file was unable to be processed.")
vector_dtype, metric, num_points, dims = _ensure_index_metadata(
index_prefix,
vector_dtype,
Expand Down Expand Up @@ -109,7 +125,7 @@ def __init__(
)

def search(
self, query: VectorLike, k_neighbors: int, complexity: int
self, query: VectorLike, k_neighbors: int, complexity: int, filter_label: str = ""
) -> QueryResponse:
"""
Searches the index by a single query vector.
Expand All @@ -121,13 +137,25 @@ def search(
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
"""
if filter_label != "":
if len(self._labels_map) == 0:
raise ValueError(
f"A filter label of {filter_label} was provided, but this class was not initialized with filters "
"enabled, e.g. StaticDiskMemory(..., enable_filters=True)"
)
if filter_label not in self._labels_map:
raise ValueError(
f"A filter label of {filter_label} was provided, but the external(str)->internal(np.uint32) labels map "
f"does not include that label."
)
k_neighbors = min(k_neighbors, self._labels_metadata[filter_label])
_query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
_assert(len(_query.shape) == 1, "query vector must be 1-d")
_assert(
_query.shape[0] == self._dimensions,
f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
f"query dimensionality: {_query.shape[0]}",
)
)
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
_assert_is_nonnegative_uint32(complexity, "complexity")

Expand All @@ -136,9 +164,20 @@ def search(
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)

if filter_label == "":
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
else:
filter = self._labels_map[filter_label]
neighbors, distances = self._index.search_with_filter(
query=query,
knn=k_neighbors,
complexity=complexity,
filter=filter
)
return QueryResponse(identifiers=neighbors, distances=distances)


def batch_search(
self,
queries: VectorLikeBatch,
Expand Down
60 changes: 53 additions & 7 deletions python/src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,37 @@ template void build_disk_index<uint8_t>(diskann::Metric, const std::string &, co
template void build_disk_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
double, double, uint32_t, uint32_t);

/// Prepares the label inputs for a filtered index build.
///
/// Derives two sibling paths from `index_output_path` ("<prefix>_label_formatted.txt" and
/// "<prefix>_labels_map.txt") and hands them, together with the caller's raw label file and
/// universal label, to `convert_labels_string_to_int`.
/// NOTE(review): presumably that helper writes the integer-encoded labels to the first path and the
/// string->integer mapping to the second - confirm against its definition.
///
/// @param index               Index under construction; only touched when a universal label is given.
/// @param index_output_path   Prefix used to derive the two generated file paths.
/// @param filter_labels_file  Path to the string-label file supplied by the caller.
/// @param universal_label     Universal label string; empty means "no universal label".
/// @return Path of the formatted label file ("<prefix>_label_formatted.txt").
template <typename T, typename TagT, typename LabelT>
std::string prepare_filtered_label_map(diskann::Index<T, TagT, LabelT> &index, const std::string &index_output_path,
                                       const std::string &filter_labels_file, const std::string &universal_label)
{
    std::string label_map_path = index_output_path + "_labels_map.txt";
    std::string formatted_labels_path = index_output_path + "_label_formatted.txt";

    convert_labels_string_to_int(filter_labels_file, formatted_labels_path, label_map_path, universal_label);

    if (universal_label.empty())
    {
        return formatted_labels_path;
    }

    // NOTE(review): the universal label is always registered as numeric id 0; presumably this matches
    // the id the conversion step assigns to it - confirm.
    uint32_t universal_label_id = 0;
    index.set_universal_label(universal_label_id);
    return formatted_labels_path;
}

// Explicit instantiations of prepare_filtered_label_map for the float, int8_t and uint8_t element
// types. Only the element type is spelled out in the template argument list; the tag and label types
// are both fixed to uint32_t by the diskann::Index parameter type in each declaration.
template std::string prepare_filtered_label_map<float>(diskann::Index<float, uint32_t, uint32_t> &, const std::string &,
const std::string &, const std::string &);

template std::string prepare_filtered_label_map<int8_t>(diskann::Index<int8_t, uint32_t, uint32_t> &,
const std::string &, const std::string &, const std::string &);

template std::string prepare_filtered_label_map<uint8_t>(diskann::Index<uint8_t, uint32_t, uint32_t> &,
const std::string &, const std::string &, const std::string &);

template <typename T, typename TagT, typename LabelT>
void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
const float alpha, const uint32_t num_threads, const bool use_pq_build,
const size_t num_pq_bytes, const bool use_opq, const uint32_t filter_complexity,
const bool use_tags)
const size_t num_pq_bytes, const bool use_opq, const bool use_tags,
const std::string &filter_labels_file, const std::string &universal_label,
const uint32_t filter_complexity)
{
diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree)
.with_filter_list_size(filter_complexity)
Expand Down Expand Up @@ -65,23 +90,44 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
size_t tag_dims = 1;
diskann::load_bin(tags_file, tags_data, data_num, tag_dims);
std::vector<TagT> tags(tags_data, tags_data + data_num);
index.build(vector_bin_path.c_str(), data_num, tags);
if (filter_labels_file.empty())
{
index.build(vector_bin_path.c_str(), data_num, tags);
}
else
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags);
}
}
else
{
index.build(vector_bin_path.c_str(), data_num);
if (filter_labels_file.empty())
{
index.build(vector_bin_path.c_str(), data_num);
}
else
{
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file,
universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num);
}
}

index.save(index_output_path.c_str());
}

template void build_memory_index<float>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool, bool, const std::string &,
const std::string &, uint32_t);

template void build_memory_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool, bool, const std::string &,
const std::string &, uint32_t);

template void build_memory_index<uint8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool, bool, const std::string &,
const std::string &, uint32_t);

} // namespace diskannpy
1 change: 0 additions & 1 deletion python/src/diskann_bindings.cpp

This file was deleted.

Loading

0 comments on commit 4a57e89

Please sign in to comment.