Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve RBC eps-neighborhood query performance #2211

Merged
merged 13 commits into from
Mar 11, 2024
Previous commit
Next commit
add restrict, free 1 register
  • Loading branch information
mfoerste4 committed Mar 7, 2024
commit 663dabf6bf6f0fd66a8a7d51415a5dac788be302
57 changes: 31 additions & 26 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -719,21 +719,22 @@ template <typename value_idx = std::int64_t,
int dim = 3,
typename value_int = std::uint32_t,
typename distance_func>
RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered,
const value_t* X,
const value_int n_queries,
const value_int n_cols,
const value_t* R,
const value_int m,
const value_t eps,
const value_int n_landmarks,
const value_idx* R_indptr,
const value_idx* R_1nn_cols,
const value_t* R_1nn_dists,
const value_t* R_radius,
distance_func dfunc,
value_idx* adj_ia,
value_idx* adj_ja)
RAFT_KERNEL __launch_bounds__(tpb)
block_rbc_kernel_eps_csr_pass_xd(const value_t* __restrict__ X_reordered,
const value_t* __restrict__ X,
const value_int n_queries,
const value_int n_cols,
const value_t* __restrict__ R,
const value_int m,
const value_t eps,
const value_int n_landmarks,
const value_idx* __restrict__ R_indptr,
const value_idx* __restrict__ R_1nn_cols,
const value_t* __restrict__ R_1nn_dists,
const value_t* __restrict__ R_radius,
distance_func dfunc,
value_idx* __restrict__ adj_ia,
value_idx* adj_ja)
{
constexpr int num_warps = tpb / WarpSize;
constexpr int max_lid = WarpSize - 1;
Expand All @@ -748,10 +749,14 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
// this is an early out for a full warp
if (query_id >= n_queries) return;

value_idx column_index_offset = write_pass ? adj_ia[query_id] : 0;
uint32_t column_index_offset = 0;

// we have no neighbors to fill for this query
if (write_pass && adj_ia[query_id + 1] == column_index_offset) return;
if constexpr (write_pass) {
value_idx offset = adj_ia[query_id];
// we have no neighbors to fill for this query
if (offset == adj_ia[query_id + 1]) return;
adj_ja += offset;
}

const value_t* x_ptr = X + (dim * query_id);
value_t local_x_ptr[dim];
Expand Down Expand Up @@ -812,11 +817,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const value_idx row_pos = column_index_offset + __popc(mask & lid_mask);
adj_ja[row_pos] = index;
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
column_index_offset += __popc(mask);
adj_ja += __popc(mask);
} else {
column_index_offset += (in_range);
}
Expand All @@ -833,11 +838,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const value_idx row_pos = column_index_offset + __popc(mask & lid_mask);
adj_ja[row_pos] = index;
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
column_index_offset += __popc(mask);
adj_ja += __popc(mask);
} else {
column_index_offset += (in_range);
}
Expand Down