Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RBC eps-neighborhood query performance #2211

Merged
merged 13 commits into from
Mar 11, 2024
Prev Previous commit
Next Next commit
utilize reordered index for rbc knn
  • Loading branch information
mfoerste4 committed Mar 5, 2024
commit 0129748a48b4637078e88c14bbff3a76c6088fa3
36 changes: 18 additions & 18 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ template <typename value_idx,
int thread_q = 2,
int tpb = 128,
int col_q = 2>
RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered,
const value_t* X,
const value_int n_cols,
bitset_type* bitset,
Expand Down Expand Up @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index,
// the closest k neighbors, compute it and add to k-select
value_t dist = std::numeric_limits<value_t>::max();
if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -313,7 +313,7 @@ template <typename value_idx = std::int64_t,
int col_q = 2,
typename value_int = std::uint32_t,
typename distance_func>
RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered,
const value_t* X,
value_int n_cols, // n_cols should be 2 or 3 dims
const value_idx* R_knn_inds,
Expand Down Expand Up @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
value_t dist = std::numeric_limits<value_t>::max();

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand All @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index,
value_t dist = std::numeric_limits<value_t>::max();

if (z <= heap.warpKTop) {
const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
value_t local_y_ptr[col_q];
for (value_int j = 0; j < n_cols; ++j) {
local_y_ptr[j] = y_ptr[j];
Expand Down Expand Up @@ -1013,7 +1013,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
if (k <= 32)
block_rbc_kernel_registers<value_idx, value_t, 32, 2, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1033,7 +1033,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 64)
block_rbc_kernel_registers<value_idx, value_t, 64, 3, 128, 2, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1052,7 +1052,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 128)
block_rbc_kernel_registers<value_idx, value_t, 128, 3, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1072,7 +1072,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 256)
block_rbc_kernel_registers<value_idx, value_t, 256, 4, 128, dims, value_int>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1092,7 +1092,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 512)
block_rbc_kernel_registers<value_idx, value_t, 512, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand All @@ -1112,7 +1112,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle,
else if (k <= 1024)
block_rbc_kernel_registers<value_idx, value_t, 1024, 8, 64, dims, value_int>
<<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
R_knn_inds,
Expand Down Expand Up @@ -1182,7 +1182,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1208,7 +1208,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1234,7 +1234,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1260,7 +1260,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
128,
dims>
<<<n_query_rows, 128, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1285,7 +1285,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
8,
64,
dims><<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand All @@ -1310,7 +1310,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle,
8,
64,
dims><<<n_query_rows, 64, 0, resource::get_cuda_stream(handle)>>>(
index.get_X().data_handle(),
index.get_X_reordered().data_handle(),
query,
index.n,
bitset.data(),
Expand Down
Loading