
improved device_resources_snmg
viclafargue committed Nov 15, 2024
1 parent 3a99b40 commit e16b68e
Showing 17 changed files with 712 additions and 725 deletions.
2 changes: 2 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h
@@ -45,6 +45,8 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {
{
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
+
+ clique_.set_memory_pool(80);
}

void build(const T* dataset, size_t nrow) final;
2 changes: 2 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h
@@ -40,6 +40,8 @@ class cuvs_mg_ivf_flat : public algo<T>, public algo_gpu {
: algo<T>(metric, dim), index_params_(param), clique_()
{
index_params_.metric = parse_metric_type(metric);
+
+ clique_.set_memory_pool(80);
}

void build(const T* dataset, size_t nrow) final;
2 changes: 2 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h
@@ -40,6 +40,8 @@ class cuvs_mg_ivf_pq : public algo<T>, public algo_gpu {
: algo<T>(metric, dim), index_params_(param), clique_()
{
index_params_.metric = parse_metric_type(metric);
+
+ clique_.set_memory_pool(80);
}

void build(const T* dataset, size_t nrow) final;
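
All three benchmark wrappers gain the same two lines: the clique now reserves a pooled device allocator at construction time, rather than paying cudaMalloc cost on every allocation during build and search. A minimal sketch of that construction pattern, assuming set_memory_pool takes a percentage of free GPU memory and lives in the raft core header below (both are assumptions, not verified against RAFT):

// Sketch of the wrapper-constructor pattern above; the header path and the
// percentage semantics of set_memory_pool are assumptions.
#include <raft/core/device_resources_snmg.hpp>

struct mg_algo_base {
  mg_algo_base() : clique_()
  {
    // Carve out ~80% of each GPU's free memory as a pool once, up front,
    // so per-rank allocations during build/search avoid cudaMalloc churn.
    clique_.set_memory_pool(80);
  }

  raft::device_resources_snmg clique_;  // one rank per visible GPU
};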
12 changes: 6 additions & 6 deletions cpp/src/neighbors/mg/generate_mg.py
@@ -57,7 +57,7 @@
const mg::index_params<ivf_flat::index_params>& index_params, \\
raft::host_matrix_view<const T, int64_t, row_major> index_dataset) \\
{ \\
- index<ivf_flat::index<T, IdxT>, T, IdxT> index(index_params.mode, clique.num_ranks_); \\
+ index<ivf_flat::index<T, IdxT>, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::build(clique, index, \\
static_cast<const cuvs::neighbors::index_params*>(&index_params), \\
index_dataset); \\
@@ -104,7 +104,7 @@
index<ivf_flat::index<T, IdxT>, T, IdxT> distribute_flat<T, IdxT>(const raft::device_resources_snmg& clique, \\
const std::string& filename) \\
{ \\
- auto idx = index<ivf_flat::index<T, IdxT>, T, IdxT>(REPLICATED, clique.num_ranks_); \\
+ auto idx = index<ivf_flat::index<T, IdxT>, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\
return idx; \\
}
@@ -116,7 +116,7 @@
const mg::index_params<ivf_pq::index_params>& index_params, \\
raft::host_matrix_view<const T, int64_t, row_major> index_dataset) \\
{ \\
- index<ivf_pq::index<IdxT>, T, IdxT> index(index_params.mode, clique.num_ranks_); \\
+ index<ivf_pq::index<IdxT>, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::build(clique, index, \\
static_cast<const cuvs::neighbors::index_params*>(&index_params), \\
index_dataset); \\
@@ -163,7 +163,7 @@
index<ivf_pq::index<IdxT>, T, IdxT> distribute_pq<T, IdxT>(const raft::device_resources_snmg& clique, \\
const std::string& filename) \\
{ \\
- auto idx = index<ivf_pq::index<IdxT>, T, IdxT>(REPLICATED, clique.num_ranks_); \\
+ auto idx = index<ivf_pq::index<IdxT>, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\
return idx; \\
}
@@ -175,7 +175,7 @@
const mg::index_params<cagra::index_params>& index_params, \\
raft::host_matrix_view<const T, int64_t, row_major> index_dataset) \\
{ \\
- index<cagra::index<T, IdxT>, T, IdxT> index(index_params.mode, clique.num_ranks_); \\
+ index<cagra::index<T, IdxT>, T, IdxT> index(index_params.mode, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::build(clique, index, \\
static_cast<const cuvs::neighbors::index_params*>(&index_params), \\
index_dataset); \\
@@ -214,7 +214,7 @@
index<cagra::index<T, IdxT>, T, IdxT> distribute_cagra<T, IdxT>(const raft::device_resources_snmg& clique, \\
const std::string& filename) \\
{ \\
- auto idx = index<cagra::index<T, IdxT>, T, IdxT>(REPLICATED, clique.num_ranks_); \\
+ auto idx = index<cagra::index<T, IdxT>, T, IdxT>(REPLICATED, clique.get_num_ranks()); \\
cuvs::neighbors::mg::detail::deserialize_and_distribute(clique, idx, filename); \\
return idx; \\
}
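
Throughout the generator, direct reads of clique.num_ranks_ become get_num_ranks() calls, so the generated instantiations stop depending on the clique's internals. A hypothetical call site for one generated entry point, with the call shape taken from the macro bodies above (the includes and public namespace are assumptions):

// Hypothetical usage of the generated distribute_flat<float, int64_t>
// instantiation; the function name follows the macro body above, while the
// header paths and namespace are assumptions.
#include <cuvs/neighbors/mg.hpp>  // assumed public header
#include <raft/core/device_resources_snmg.hpp>

#include <string>

void load_replicated_ivf_flat(const std::string& filename)
{
  raft::device_resources_snmg clique;  // one rank per visible GPU

  // Sizes the replicated index with clique.get_num_ranks(), then
  // deserializes one copy of the IVF-Flat index onto every rank.
  auto idx = cuvs::neighbors::mg::distribute_flat<float, int64_t>(clique, filename);
  (void)idx;
}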
96 changes: 38 additions & 58 deletions cpp/src/neighbors/mg/mg.cuh
@@ -51,10 +51,8 @@ void deserialize_and_distribute(const raft::device_resources_snmg& clique,
const std::string& filename)
{
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- auto& ann_if = index.ann_interfaces_.emplace_back();
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ auto& ann_if = index.ann_interfaces_.emplace_back();
cuvs::neighbors::deserialize(dev_res, ann_if, filename);
}
}
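
This hunk shows the commit's core refactor: the repeated three-step dance of reading device_ids_[rank], grabbing device_resources_[rank], and calling cudaSetDevice collapses into a single accessor that switches the calling thread to that rank's GPU and returns its handle. A sketch of what such an accessor could look like, reconstructed from the members visible in the removed lines (not the actual RAFT implementation):

// Sketch of a set_current_device_to_rank-style accessor, reconstructed from
// the device_ids_/device_resources_ members in the removed lines above; the
// real raft::device_resources_snmg and the header locations may differ.
#include <raft/core/device_resources.hpp>
#include <raft/util/cudart_utils.hpp>  // RAFT_CUDA_TRY (assumed location)

#include <vector>

class clique_sketch {
 public:
  const raft::device_resources& set_current_device_to_rank(int rank) const
  {
    // Bind rank's GPU to the calling (possibly OpenMP) thread, then hand
    // back that rank's resources: one call instead of three lines per loop.
    RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank]));
    return device_resources_[rank];
  }

 private:
  std::vector<int> device_ids_;
  std::vector<raft::device_resources> device_resources_;
};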
@@ -72,17 +70,15 @@ void deserialize(const raft::device_resources_snmg& clique,
index.mode_ = (cuvs::neighbors::mg::distribution_mode)deserialize_scalar<int>(handle, is);
index.num_ranks_ = deserialize_scalar<int>(handle, is);

- if (index.num_ranks_ != clique.num_ranks_) {
+ if (index.num_ranks_ != clique.get_num_ranks()) {
RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks",
index.num_ranks_,
- clique.num_ranks_);
+ clique.get_num_ranks());
}

for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- auto& ann_if = index.ann_interfaces_.emplace_back();
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ auto& ann_if = index.ann_interfaces_.emplace_back();
cuvs::neighbors::deserialize(dev_res, ann_if, is);
}

@@ -102,10 +98,8 @@ void build(const raft::device_resources_snmg& clique,
index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- auto& ann_if = index.ann_interfaces_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset);
resource::sync_stream(dev_res);
}
@@ -119,13 +113,11 @@ void build(const raft::device_resources_snmg& clique,
index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- int64_t offset = rank * n_rows_per_shard;
- int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
- const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols);
- auto partition = raft::make_host_matrix_view<const T, int64_t, row_major>(
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ int64_t offset = rank * n_rows_per_shard;
+ int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
+ const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols);
+ auto partition = raft::make_host_matrix_view<const T, int64_t, row_major>(
partition_ptr, n_rows_of_current_shard, n_cols);
auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::build(dev_res, ann_if, index_params, partition);
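
In the sharded build path each rank indexes an even slice of the host dataset, with only the last shard allowed to run short. A standalone sketch of that partitioning arithmetic, assuming n_rows_per_shard is the ceil-divided shard size (the min() in the loop above implies exactly this behavior for the last rank):

// Standalone check of the shard arithmetic used by the sharded build/extend
// loops above; n_rows_per_shard as ceil-division is an assumption.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
  const int64_t n_rows = 10000, n_ranks = 3;
  const int64_t n_rows_per_shard = (n_rows + n_ranks - 1) / n_ranks;  // 3334

  for (int64_t rank = 0; rank < n_ranks; rank++) {
    const int64_t offset  = rank * n_rows_per_shard;
    // The last rank gets whatever remains, never more than a full shard.
    const int64_t n_shard = std::min(n_rows_per_shard, n_rows - offset);  // rank 2: 3332
    std::printf("rank %lld: rows [%lld, %lld)\n",
                (long long)rank, (long long)offset, (long long)(offset + n_shard));
  }
}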
@@ -146,10 +138,8 @@ void extend(const raft::device_resources_snmg& clique,

#pragma omp parallel for
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- auto& ann_if = index.ann_interfaces_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::extend(dev_res, ann_if, new_vectors, new_indices);
resource::sync_stream(dev_res);
}
@@ -161,13 +151,11 @@ void extend(const raft::device_resources_snmg& clique,

#pragma omp parallel for
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- int64_t offset = rank * n_rows_per_shard;
- int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
- const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols);
- auto new_vectors_part = raft::make_host_matrix_view<const T, int64_t, row_major>(
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ int64_t offset = rank * n_rows_per_shard;
+ int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
+ const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols);
+ auto new_vectors_part = raft::make_host_matrix_view<const T, int64_t, row_major>(
new_vectors_ptr, n_rows_of_current_shard, n_cols);

std::optional<raft::host_vector_view<const IdxT, int64_t>> new_indices_part = std::nullopt;
@@ -219,13 +207,11 @@ void sharded_search_with_direct_merge(const raft::device_resources_snmg& clique,
check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang
#pragma omp parallel for num_threads(index.num_ranks_)
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
auto& ann_if = index.ann_interfaces_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));

- if (rank == clique.root_rank_) { // root rank
- uint64_t batch_offset = clique.root_rank_ * part_size;
+ if (rank == clique.get_root_rank()) { // root rank
+ uint64_t batch_offset = clique.get_root_rank() * part_size;
auto d_neighbors = raft::make_device_matrix_view<IdxT, int64_t, row_major>(
in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors);
auto d_distances = raft::make_device_matrix_view<float, int64_t, row_major>(
@@ -236,20 +222,20 @@
// wait for other ranks
ncclGroupStart();
for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) {
- if (from_rank == clique.root_rank_) continue;
+ if (from_rank == clique.get_root_rank()) continue;

batch_offset = from_rank * part_size;
ncclRecv(in_neighbors.data_handle() + batch_offset,
part_size * sizeof(IdxT),
ncclUint8,
from_rank,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
ncclRecv(in_distances.data_handle() + batch_offset,
part_size * sizeof(float),
ncclUint8,
from_rank,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
}
ncclGroupEnd();
@@ -267,14 +253,14 @@
ncclSend(d_neighbors.data_handle(),
part_size * sizeof(IdxT),
ncclUint8,
- clique.root_rank_,
- clique.nccl_comms_[rank],
+ clique.get_root_rank(),
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
ncclSend(d_distances.data_handle(),
part_size * sizeof(float),
ncclUint8,
- clique.root_rank_,
- clique.nccl_comms_[rank],
+ clique.get_root_rank(),
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
ncclGroupEnd();
resource::sync_stream(dev_res);
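
In the direct-merge strategy every non-root rank ships its shard's neighbors and distances to the root, which posts all matching receives inside a single ncclGroupStart/ncclGroupEnd so the transfers progress concurrently; payloads travel as raw bytes (ncclUint8 with part_size * sizeof(...) counts), which works for any IdxT without needing a matching NCCL datatype. A stripped-down sketch of the same gather shape for one float buffer (helper name and setup are hypothetical):

// Minimal sketch of the root-gather pattern above for a single buffer.
// Comm/stream setup, error checking, and the local top-k merge are elided.
#include <nccl.h>

#include <cstddef>

void gather_parts_to_root(float* buf,        // root: full buffer; others: own part
                          std::size_t part,  // elements contributed per rank
                          int rank, int n_ranks, int root,
                          ncclComm_t comm, cudaStream_t stream)
{
  ncclGroupStart();
  if (rank == root) {
    for (int from = 0; from < n_ranks; from++) {
      if (from == root) continue;  // root's own part is already in place
      ncclRecv(buf + from * part, part * sizeof(float), ncclUint8, from, comm, stream);
    }
  } else {
    ncclSend(buf, part * sizeof(float), ncclUint8, root, comm, stream);
  }
  ncclGroupEnd();
}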
@@ -342,10 +328,8 @@ void sharded_search_with_tree_merge(const raft::device_resources_snmg& clique,
check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang
#pragma omp parallel for num_threads(index.num_ranks_)
for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
auto& ann_if = index.ann_interfaces_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));

int64_t part_size = n_rows_of_current_batch * n_neighbors;

@@ -390,13 +374,13 @@
part_size * sizeof(IdxT),
ncclUint8,
other_id,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
ncclRecv(tmp_distances.data_handle() + part_size,
part_size * sizeof(float),
ncclUint8,
other_id,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
received_something = true;
}
@@ -407,13 +391,13 @@
part_size * sizeof(IdxT),
ncclUint8,
other_id,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
ncclSend(tmp_distances.data_handle(),
part_size * sizeof(float),
ncclUint8,
other_id,
- clique.nccl_comms_[rank],
+ clique.get_nccl_comm(rank),
resource::get_cuda_stream(dev_res));
}
ncclGroupEnd();
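
The tree merge reduces per-rank result lists pairwise instead of all-to-root: in each round, half of the still-active ranks send their current candidates to a partner that merges and re-sorts them, so the root holds the final top-k after about log2(num_ranks) rounds. A schedule-only sketch of one plausible pairing (a binomial tree; the exact partner computation in mg.cuh is not visible in these hunks, so this is an assumption):

// Pairing-schedule sketch for a binomial-tree reduction over ranks; prints
// who sends to whom per round. NCCL calls and the actual merge are elided.
#include <cstdio>

int main()
{
  const int n_ranks = 6;
  for (int stride = 1; stride < n_ranks; stride *= 2) {
    for (int rank = 0; rank < n_ranks; rank++) {
      if (rank % (2 * stride) == stride) {
        // Sender retires after shipping its candidates down the tree.
        std::printf("stride %d: rank %d sends to rank %d\n", stride, rank, rank - stride);
      } else if (rank % (2 * stride) == 0 && rank + stride < n_ranks) {
        std::printf("stride %d: rank %d merges results from rank %d\n", stride, rank, rank + stride);
      }
    }
  }
}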
@@ -466,9 +450,7 @@ void run_search_batch(const raft::device_resources_snmg& clique,
int64_t n_cols,
int64_t n_neighbors)
{
- int dev_id = clique.device_ids_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- const raft::device_resources& dev_res = clique.device_resources_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
auto& ann_if = index.ann_interfaces_[rank];

auto query_partition = raft::make_host_matrix_view<const T, int64_t, row_major>(
@@ -645,10 +627,8 @@ void serialize(const raft::device_resources_snmg& clique,
serialize_scalar(handle, of, index.num_ranks_);

for (int rank = 0; rank < index.num_ranks_; rank++) {
- int dev_id = clique.device_ids_[rank];
- const raft::device_resources& dev_res = clique.device_resources_[rank];
- RAFT_CUDA_TRY(cudaSetDevice(dev_id));
- auto& ann_if = index.ann_interfaces_[rank];
+ const raft::device_resources& dev_res = clique.set_current_device_to_rank(rank);
+ auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::serialize(dev_res, ann_if, of);
}
