Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of select-top-k WARP_SORT implementation #606

Merged
merged 13 commits into from
May 16, 2022
Prev Previous commit
Next Next commit
Try to merge less often - when the buffer overflows for some threads.
  • Loading branch information
achirkin committed Apr 1, 2022
commit 9595ad2fd9471df3a4f028d17c49e01c5ba7f683
52 changes: 26 additions & 26 deletions cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,18 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
{
// comparing for k_th should reduce the total amount of updates:
// `false` means the input value is surely not in the top-k values.
if (is_ordered<Ascending>(val, k_th_)) {
// NB: the loop is used here to ensure the constant indexing,
// to not force the buffers spill into the local memory.
#pragma unroll
for (int i = 0; i < kMaxBufLen; i++) {
if (i == buf_len_) {
val_buf_[i] = val;
idx_buf_[i] = idx;
}
bool do_add = is_ordered<Ascending>(val, k_th_);
// merge the buf if it's full and we cannot add an element anymore.
if (any(buf_len_ + do_add > kMaxBufLen)) {
// still, add an element before merging if possible for this thread
if (do_add && buf_len_ < kMaxBufLen) {
add_to_buf_(val, idx);
do_add = false;
}
++buf_len_;
merge_buf_();
}
if (any(buf_len_ == kMaxBufLen)) { merge_buf_(); }
// add an element if necessary and haven't already.
if (do_add) { add_to_buf_(val, idx); }
}

__device__ void done()
Expand All @@ -345,25 +344,14 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
}

private:
__device__ void set_k_th_()
__device__ __forceinline__ void set_k_th_()
{
// NB on using srcLane: it's ok if it is outside the warp size / width;
// the modulo op will be done inside the __shfl_sync.
k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth);

// synchronize between subwarps to find the earliest k_th_ among them.
if constexpr (kWarpWidth < WarpSize) {
if (k_th_share_counter_-- == 0) {
#pragma unroll
for (int width = kWarpWidth << 1; width <= WarpSize; width <<= 1) {
k_th_ = select_first<Ascending>(k_th_, shfl_xor(k_th_, width - 1));
}
k_th_share_counter_ = kThShareInterval;
}
}
}

__device__ void merge_buf_()
__device__ __forceinline__ void merge_buf_()
{
topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
Expand All @@ -375,6 +363,20 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
}
}

/**
 * Append one (value, index) pair to this thread's buffer and advance buf_len_.
 *
 * The pair is stored at position buf_len_. If buf_len_ is not within
 * [0, kMaxBufLen), nothing is stored, but buf_len_ is still incremented.
 *
 * @param val the value to append
 * @param idx the index associated with the value
 */
__device__ __forceinline__ void add_to_buf_(T val, IdxT idx)
{
  // NB: every slot is visited with a loop-constant index and only the slot equal to
  // buf_len_ is written; indexing the buffers directly with the runtime buf_len_ could
  // force them to spill into the local memory.
#pragma unroll
  for (int slot = 0; slot < kMaxBufLen; ++slot) {
    const bool is_target = (slot == buf_len_);
    if (is_target) {
      idx_buf_[slot] = idx;
      val_buf_[slot] = val;
    }
  }
  ++buf_len_;
}

using warp_sort<Capacity, Ascending, T, IdxT>::kMaxArrLen;
using warp_sort<Capacity, Ascending, T, IdxT>::val_arr_;
using warp_sort<Capacity, Ascending, T, IdxT>::idx_arr_;
Expand All @@ -386,8 +388,6 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
int buf_len_;

T k_th_;
int k_th_share_counter_ = 0;
static constexpr int kThShareInterval = 1 + WarpSize / kWarpWidth;
};

/**
Expand Down