Skip to content

Commit

Permalink
added in-mem index class to python bindings (#301)
Browse files Browse the repository at this point in the history
* added in-mem dynamic index class to python bindings

* added static in memory index python binding

* clang format
  • Loading branch information
harsha-simhadri authored Apr 5, 2023
1 parent d9fb349 commit 949167c
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 20 deletions.
14 changes: 12 additions & 2 deletions python/src/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@
np.byte: _native_dap.DiskANNInt8Index,
}

_DTYPE_TO_NATIVE_INMEM_DYNAMIC_INDEX = {
np.single: _native_dap.DiskANNDynamicInMemFloatIndex,
np.ubyte: _native_dap.DiskANNDynamicInMemUInt8Index,
np.byte: _native_dap.DiskANNDynamicInMemInt8Index,
}

_DTYPE_TO_NATIVE_INMEM_STATIC_INDEX = {
np.single: _native_dap.DiskANNStaticInMemFloatIndex,
np.ubyte: _native_dap.DiskANNStaticInMemUInt8Index,
np.byte: _native_dap.DiskANNStaticInMemInt8Index,
}


VectorDType = TypeVar("VectorDType", Type[np.single], Type[np.ubyte], Type[np.byte])

Expand Down Expand Up @@ -340,7 +352,6 @@ def search(
list_size = k_neighbors
return self._index.search(
query=query,
dim=query.shape[0],
knn=k_neighbors,
l_search=list_size,
beam_width=beam_width,
Expand Down Expand Up @@ -409,7 +420,6 @@ def batch_search(
num_queries, dim = queries.shape
return self._index.batch_search(
queries=queries,
dim=dim,
num_queries=num_queries,
knn=k_neighbors,
l_search=list_size,
Expand Down
231 changes: 213 additions & 18 deletions python/src/diskann_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#endif

#include "disk_utils.h"
#include "index.h"
#include "pq_flash_index.h"

PYBIND11_MAKE_OPAQUE(std::vector<unsigned>);
Expand Down Expand Up @@ -98,8 +99,8 @@ template <class T> struct DiskANNIndex
return 0;
}

auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t dim,
const uint64_t knn, const uint64_t l_search, const uint64_t beam_width)
auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
const uint64_t l_search, const uint64_t beam_width)
{
py::array_t<unsigned> ids(knn);
py::array_t<float> dists(knn);
Expand All @@ -118,9 +119,8 @@ template <class T> struct DiskANNIndex
return std::make_pair(ids, dists);
}

auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t dim,
const uint64_t num_queries, const uint64_t knn, const uint64_t l_search,
const uint64_t beam_width, const int num_threads)
auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
const uint64_t knn, const uint64_t l_search, const uint64_t beam_width, const int num_threads)
{
py::array_t<unsigned> ids({num_queries, knn});
py::array_t<float> dists({num_queries, knn});
Expand All @@ -130,7 +130,7 @@ template <class T> struct DiskANNIndex
std::vector<uint64_t> u64_ids(knn * num_queries);

#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < num_queries; i++)
for (int64_t i = 0; i < (int64_t)num_queries; i++)
{
pq_flash_index->cached_beam_search(queries.data(i), knn, l_search, u64_ids.data() + i * knn,
dists.mutable_data(i), beam_width);
Expand All @@ -145,6 +145,132 @@ template <class T> struct DiskANNIndex
}
};

typedef uint32_t IdT;
typedef uint32_t filterT;

template <class T> struct DynamicInMemIndex
{
Index<T, IdT, filterT> *_index;
const IndexWriteParameters write_params;

DynamicInMemIndex(Metric m, const size_t dim, const size_t max_points, const IndexWriteParameters &index_parameters,
const IndexReadParameters &search_parameters, const bool concurrent_consolidate)
: write_params(index_parameters)
{
_index = new Index<T>(m, dim, max_points,
true, // dynamic_index
index_parameters, // used for insert
search_parameters, // used for searching
true, // enable_tags
concurrent_consolidate,
false, // pq_dist_build
0, // num_pq_chunks
false); // use_opq = false
}

~DynamicInMemIndex()
{
delete _index;
}

int insert(py::array_t<T, py::array::c_style | py::array::forcecast> &vector, const IdT id)
{
return _index->insert_point(vector.data(), id);
}

int mark_deleted(const IdT id)
{
return _index->lazy_delete(id);
}

auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
const uint64_t l_search)
{
py::array_t<IdT> ids(knn);
py::array_t<float> dists(knn);
std::vector<T *> empty_vector;
_index->search_with_tags(query.data(), knn, l_search, ids.mutable_data(), dists.mutable_data(), empty_vector);
return std::make_pair(ids, dists);
}

auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
const uint64_t knn, const uint64_t l_search, const int num_threads)
{
py::array_t<unsigned> ids({num_queries, knn});
py::array_t<float> dists({num_queries, knn});
std::vector<T *> empty_vector;

omp_set_num_threads(num_threads);

#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)num_queries; i++)
{
_index->search_with_tags(queries.data(i), knn, l_search, ids.mutable_data(i), dists.mutable_data(i),
empty_vector);
}

return std::make_pair(ids, dists);
}

auto consolidate_delete()
{
return _index->consolidate_deletes(write_params);
}
};

template <class T> struct StaticInMemIndex
{
Index<T, IdT, filterT> *_index;

StaticInMemIndex(Metric m, const std::string &data_path, IndexWriteParameters &index_parameters)
{
size_t ndims, npoints;
diskann::get_bin_metadata(data_path, npoints, ndims);
_index = new Index<T>(m, ndims, npoints,
false, // not a dynamic_index
false, // no enable_tags/ids
false, // no concurrent_consolidate,
false, // pq_dist_build
0, // num_pq_chunks
false, // use_opq = false
0); // num_frozen_pts = 0
_index->build(data_path.c_str(), npoints, index_parameters);
}

~StaticInMemIndex()
{
delete _index;
}

auto search(py::array_t<T, py::array::c_style | py::array::forcecast> &query, const uint64_t knn,
const uint64_t l_search)
{
py::array_t<IdT> ids(knn);
py::array_t<float> dists(knn);
std::vector<T *> empty_vector;
_index->search(query.data(), knn, l_search, ids.mutable_data(), dists.mutable_data());
return std::make_pair(ids, dists);
}

auto batch_search(py::array_t<T, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
const uint64_t knn, const uint64_t l_search, const int num_threads)
{
py::array_t<unsigned> ids({num_queries, knn});
py::array_t<float> dists({num_queries, knn});
std::vector<T *> empty_vector;

omp_set_num_threads(num_threads);

#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)num_queries; i++)
{
_index->search(queries.data(i), knn, l_search, ids.mutable_data(i), dists.mutable_data(i));
}

return std::make_pair(ids, dists);
}
};

PYBIND11_MODULE(_diskannpy, m)
{
m.doc() = "DiskANN Python Bindings";
Expand All @@ -159,17 +285,86 @@ PYBIND11_MODULE(_diskannpy, m)
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
.export_values();

py::class_<StaticInMemIndex<float>>(m, "DiskANNStaticInMemFloatIndex")
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
return std::unique_ptr<StaticInMemIndex<float>>(
new StaticInMemIndex<float>(metric, data_path, index_parameters));
}))
.def("search", &StaticInMemIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &StaticInMemIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));

py::class_<StaticInMemIndex<int8_t>>(m, "DiskANNStaticInMemInt8Index")
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
return std::unique_ptr<StaticInMemIndex<int8_t>>(
new StaticInMemIndex<int8_t>(metric, data_path, index_parameters));
}))
.def("search", &StaticInMemIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &StaticInMemIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));

py::class_<StaticInMemIndex<uint8_t>>(m, "DiskANNStaticInMemUint8Index")
.def(py::init([](diskann::Metric metric, const std::string &data_path, IndexWriteParameters &index_parameters) {
return std::unique_ptr<StaticInMemIndex<uint8_t>>(
new StaticInMemIndex<uint8_t>(metric, data_path, index_parameters));
}))
.def("search", &StaticInMemIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &StaticInMemIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"));

py::class_<DynamicInMemIndex<float>>(m, "DiskANNDynamicInMemFloatIndex")
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
const IndexWriteParameters &index_parameters, const IndexReadParameters &search_parameters,
const bool concurrent_consolidate) {
return std::unique_ptr<DynamicInMemIndex<float>>(new DynamicInMemIndex<float>(
metric, dim, max_points, index_parameters, search_parameters, concurrent_consolidate));
}))
.def("search", &DynamicInMemIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &DynamicInMemIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
.def("insert", &DynamicInMemIndex<float>::insert, py::arg("vector"), py::arg("id"))
.def("mark_deleted", &DynamicInMemIndex<float>::mark_deleted, py::arg("id"))
.def("consolidate_delete", &DynamicInMemIndex<float>::consolidate_delete);

py::class_<DynamicInMemIndex<int8_t>>(m, "DiskANNDynamicInMemInt8Index")
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
const IndexWriteParameters &index_parameters, const IndexReadParameters &search_parameters,
const bool concurrent_consolidate) {
return std::unique_ptr<DynamicInMemIndex<int8_t>>(new DynamicInMemIndex<int8_t>(
metric, dim, max_points, index_parameters, search_parameters, concurrent_consolidate));
}))
.def("search", &DynamicInMemIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &DynamicInMemIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
.def("insert", &DynamicInMemIndex<int8_t>::insert, py::arg("vector"), py::arg("id"))
.def("mark_deleted", &DynamicInMemIndex<int8_t>::mark_deleted, py::arg("id"))
.def("consolidate_delete", &DynamicInMemIndex<int8_t>::consolidate_delete);

py::class_<DynamicInMemIndex<uint8_t>>(m, "DiskANNDynamicInMemUint8Index")
.def(py::init([](diskann::Metric metric, const size_t dim, const size_t max_points,
const IndexWriteParameters &index_parameters, const IndexReadParameters &search_parameters,
const bool concurrent_consolidate) {
return std::unique_ptr<DynamicInMemIndex<uint8_t>>(new DynamicInMemIndex<uint8_t>(
metric, dim, max_points, index_parameters, search_parameters, concurrent_consolidate));
}))
.def("search", &DynamicInMemIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"))
.def("batch_search", &DynamicInMemIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("num_threads"))
.def("insert", &DynamicInMemIndex<uint8_t>::insert, py::arg("vector"), py::arg("id"))
.def("mark_deleted", &DynamicInMemIndex<uint8_t>::mark_deleted, py::arg("id"))
.def("consolidate_delete", &DynamicInMemIndex<uint8_t>::consolidate_delete);

py::class_<DiskANNIndex<float>>(m, "DiskANNFloatIndex")
.def(py::init([](diskann::Metric metric) {
return std::unique_ptr<DiskANNIndex<float>>(new DiskANNIndex<float>(metric));
}))
.def("cache_bfs_levels", &DiskANNIndex<float>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
.def("load_index", &DiskANNIndex<float>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
.def("search", &DiskANNIndex<float>::search, py::arg("query"), py::arg("dim"), py::arg("knn"),
py::arg("l_search"), py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<float>::batch_search, py::arg("queries"), py::arg("dim"),
py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def("search", &DiskANNIndex<float>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<float>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<float> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
Expand All @@ -195,10 +390,10 @@ PYBIND11_MODULE(_diskannpy, m)
.def("cache_bfs_levels", &DiskANNIndex<int8_t>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
.def("load_index", &DiskANNIndex<int8_t>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
.def("search", &DiskANNIndex<int8_t>::search, py::arg("query"), py::arg("dim"), py::arg("knn"),
py::arg("l_search"), py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<int8_t>::batch_search, py::arg("queries"), py::arg("dim"),
py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def("search", &DiskANNIndex<int8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<int8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<int8_t> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
Expand All @@ -222,10 +417,10 @@ PYBIND11_MODULE(_diskannpy, m)
.def("cache_bfs_levels", &DiskANNIndex<uint8_t>::cache_bfs_levels, py::arg("num_nodes_to_cache"))
.def("load_index", &DiskANNIndex<uint8_t>::load_index, py::arg("index_path_prefix"), py::arg("num_threads"),
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
.def("search", &DiskANNIndex<uint8_t>::search, py::arg("query"), py::arg("dim"), py::arg("knn"),
py::arg("l_search"), py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("dim"),
py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def("search", &DiskANNIndex<uint8_t>::search, py::arg("query"), py::arg("knn"), py::arg("l_search"),
py::arg("beam_width"))
.def("batch_search", &DiskANNIndex<uint8_t>::batch_search, py::arg("queries"), py::arg("num_queries"),
py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<uint8_t> &self, const char *data_file_path, const char *index_prefix_path, unsigned R,
Expand Down

0 comments on commit 949167c

Please sign in to comment.