diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h index f5a394482..27a0fd7ac 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h @@ -45,6 +45,8 @@ class cuvs_mg_cagra : public algo, public algo_gpu { { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h index 05e68b26b..5e811da33 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_flat : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h index d430d27bf..c4a820cad 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h @@ -40,6 +40,8 @@ class cuvs_mg_ivf_pq : public algo, public algo_gpu { : algo(metric, dim), index_params_(param), clique_() { index_params_.metric = parse_metric_type(metric); + + clique_.set_memory_pool(80); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/src/neighbors/mg/generate_mg.py b/cpp/src/neighbors/mg/generate_mg.py index 023f5baf3..26e81da16 100644 --- a/cpp/src/neighbors/mg/generate_mg.py +++ b/cpp/src/neighbors/mg/generate_mg.py @@ -57,7 +57,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -104,7 +104,7 @@ index, T, IdxT> distribute_flat(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -116,7 +116,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -163,7 +163,7 @@ index, T, IdxT> distribute_pq(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } @@ -175,7 +175,7 @@ const mg::index_params& index_params, \\ raft::host_matrix_view index_dataset) \\ { \\ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \\ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::build(clique, index, \\ static_cast(&index_params), \\ index_dataset); \\ @@ -214,7 +214,7 @@ index, T, IdxT> distribute_cagra(const raft::device_resources_snmg& clique, \\ const std::string& filename) \\ { \\ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \\ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\ cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\ return idx; \\ } diff --git a/cpp/src/neighbors/mg/mg.cuh b/cpp/src/neighbors/mg/mg.cuh index 0e113ef72..c6812b1e1 100644 --- a/cpp/src/neighbors/mg/mg.cuh +++ b/cpp/src/neighbors/mg/mg.cuh @@ -51,10 +51,8 @@ void deserialize_and_distribute(const raft::device_resources_snmg& clique, const std::string& filename) { for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, filename); } } @@ -72,17 +70,15 @@ void deserialize(const raft::device_resources_snmg& clique, index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar(handle, is); index.num_ranks_ = deserialize_scalar(handle, is); - if (index.num_ranks_ != clique.num_ranks_) { + if (index.num_ranks_ != clique.get_num_ranks()) { RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks", index.num_ranks_, - clique.num_ranks_); + clique.get_num_ranks()); } for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_.emplace_back(); + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_.emplace_back(); cuvs::neighbors::deserialize(dev_res, ann_if, is); } @@ -102,10 +98,8 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset); resource::sync_stream(dev_res); } @@ -119,13 +113,11 @@ void build(const raft::device_resources_snmg& clique, index.ann_interfaces_.resize(index.num_ranks_); #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::build(dev_res, ann_if, index_params, partition); @@ -146,10 +138,8 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices); resource::sync_stream(dev_res); } @@ -161,13 +151,11 @@ void extend(const raft::device_resources_snmg& clique, #pragma omp parallel for for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - int64_t offset = rank * n_rows_per_shard; - int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + int64_t offset = rank * n_rows_per_shard; + int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; @@ -219,13 +207,11 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank - uint64_t batch_offset = clique.root_rank_ * part_size; + if (rank == clique.get_root_rank()) { // root rank + uint64_t batch_offset = clique.get_root_rank() * part_size; auto d_neighbors = raft::make_device_matrix_view( in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); auto d_distances = raft::make_device_matrix_view( @@ -236,20 +222,20 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, // wait for other ranks ncclGroupStart(); for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) continue; + if (from_rank == clique.get_root_rank()) continue; batch_offset = from_rank * part_size; ncclRecv(in_neighbors.data_handle() + batch_offset, part_size * sizeof(IdxT), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(in_distances.data_handle() + batch_offset, part_size * sizeof(float), ncclUint8, from_rank, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -267,14 +253,14 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique, ncclSend(d_neighbors.data_handle(), part_size * sizeof(IdxT), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(d_distances.data_handle(), part_size * sizeof(float), ncclUint8, - clique.root_rank_, - clique.nccl_comms_[rank], + clique.get_root_rank(), + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclGroupEnd(); resource::sync_stream(dev_res); @@ -342,10 +328,8 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(index.num_ranks_) for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); int64_t part_size = n_rows_of_current_batch * n_neighbors; @@ -390,13 +374,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclRecv(tmp_distances.data_handle() + part_size, part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); received_something = true; } @@ -407,13 +391,13 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique, part_size * sizeof(IdxT), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); ncclSend(tmp_distances.data_handle(), part_size * sizeof(float), ncclUint8, other_id, - clique.nccl_comms_[rank], + clique.get_nccl_comm(rank), resource::get_cuda_stream(dev_res)); } ncclGroupEnd(); @@ -466,9 +450,7 @@ void run_search_batch(const raft::device_resources_snmg& clique, int64_t n_cols, int64_t n_neighbors) { - int dev_id = clique.device_ids_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - const raft::device_resources& dev_res = clique.device_resources_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); auto& ann_if = index.ann_interfaces_[rank]; auto query_partition = raft::make_host_matrix_view( @@ -645,10 +627,8 @@ void serialize(const raft::device_resources_snmg& clique, serialize_scalar(handle, of, index.num_ranks_); for (int rank = 0; rank < index.num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = index.ann_interfaces_[rank]; + const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank); + auto& ann_if = index.ann_interfaces_[rank]; cuvs::neighbors::serialize(dev_res, ann_if, of); } diff --git a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu index c3ef3705e..e179a56e3 100644 --- a/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_float_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(float, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu index ea9ec672b..3e369d9ac 100644 --- a/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_half_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(half, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu index aeae0f2cc..5ebf223d1 100644 --- a/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(int8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu index 22421d6f0..923031b1c 100644 --- a/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu +++ b/cpp/src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu @@ -27,61 +27,61 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_CAGRA(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_cagra( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_CAGRA(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_cagra( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_CAGRA(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu index 423aa0284..f90f6fcfb 100644 --- a/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu index 06bb7af26..2eefad5d5 100644 --- a/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu index bbf7d96f8..9684f19d8 100644 --- a/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_flat_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_FLAT(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_flat( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_FLAT(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_flat( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_FLAT(uint8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu index 441a09e2f..c71133ac4 100644 --- a/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_float_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(float, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu index bf6126fee..df148620f 100644 --- a/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_half_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(half, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu index 3921f810c..afe5faa41 100644 --- a/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_int8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(int8_t, int64_t); diff --git a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu index 8f4683fd7..c725d2139 100644 --- a/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/mg/mg_pq_uint8_t_int64_t.cu @@ -27,69 +27,69 @@ namespace cuvs::neighbors::mg { -#define CUVS_INST_MG_PQ(T, IdxT) \ - index, T, IdxT> build( \ - const raft::device_resources_snmg& clique, \ - const mg::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - { \ - index, T, IdxT> index(index_params.mode, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::build( \ - clique, \ - index, \ - static_cast(&index_params), \ - index_dataset); \ - return index; \ - } \ - \ - void extend(const raft::device_resources_snmg& clique, \ - index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ - } \ - \ - void search(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const mg::search_params& search_params, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances, \ - int64_t n_rows_per_batch) \ - { \ - cuvs::neighbors::mg::detail::search( \ - clique, \ - index, \ - static_cast(&search_params), \ - queries, \ - neighbors, \ - distances, \ - n_rows_per_batch); \ - } \ - \ - void serialize(const raft::device_resources_snmg& clique, \ - const index, T, IdxT>& index, \ - const std::string& filename) \ - { \ - cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ - } \ - \ - template <> \ - index, T, IdxT> deserialize_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(clique, filename); \ - return idx; \ - } \ - \ - template <> \ - index, T, IdxT> distribute_pq( \ - const raft::device_resources_snmg& clique, const std::string& filename) \ - { \ - auto idx = index, T, IdxT>(REPLICATED, clique.num_ranks_); \ - cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ - return idx; \ +#define CUVS_INST_MG_PQ(T, IdxT) \ + index, T, IdxT> build( \ + const raft::device_resources_snmg& clique, \ + const mg::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + { \ + index, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::build( \ + clique, \ + index, \ + static_cast(&index_params), \ + index_dataset); \ + return index; \ + } \ + \ + void extend(const raft::device_resources_snmg& clique, \ + index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices) \ + { \ + cuvs::neighbors::mg::detail::extend(clique, index, new_vectors, new_indices); \ + } \ + \ + void search(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const mg::search_params& search_params, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances, \ + int64_t n_rows_per_batch) \ + { \ + cuvs::neighbors::mg::detail::search( \ + clique, \ + index, \ + static_cast(&search_params), \ + queries, \ + neighbors, \ + distances, \ + n_rows_per_batch); \ + } \ + \ + void serialize(const raft::device_resources_snmg& clique, \ + const index, T, IdxT>& index, \ + const std::string& filename) \ + { \ + cuvs::neighbors::mg::detail::serialize(clique, index, filename); \ + } \ + \ + template <> \ + index, T, IdxT> deserialize_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(clique, filename); \ + return idx; \ + } \ + \ + template <> \ + index, T, IdxT> distribute_pq( \ + const raft::device_resources_snmg& clique, const std::string& filename) \ + { \ + auto idx = index, T, IdxT>(REPLICATED, clique.get_num_ranks()); \ + cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \ + return idx; \ } CUVS_INST_MG_PQ(uint8_t, int64_t); diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index eb97b583c..f634765c9 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -54,6 +54,7 @@ class AnnMGTest : public ::testing::TestWithParam { h_index_dataset(0), h_queries(0) { + clique_.set_memory_pool(80); } void testAnnMG()