
Commit

Fix incorrect indexing
achirkin committed Apr 13, 2022
1 parent b57afa4 commit 631ed31
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh
@@ -42,6 +42,13 @@ __host__ __device__ constexpr int calc_num_passes()
return ceildiv<int>(sizeof(T) * 8, BitsPerPass);
}

// Minimum reasonable block size for the given radix size.
template <int BitsPerPass>
__host__ __device__ constexpr int calc_min_block_size()
{
return 1 << std::max<int>(BitsPerPass - 4, Pow2<WarpSize>::Log2 + 1);
}
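
As a quick illustration of this helper (an editorial aside, assuming the usual WarpSize of 32 so that Pow2<WarpSize>::Log2 + 1 == 6):

// Assuming WarpSize == 32, i.e. Pow2<WarpSize>::Log2 == 5:
static_assert(calc_min_block_size<8>() == 64, "1 << max(8 - 4, 6)");
static_assert(calc_min_block_size<11>() == 128, "1 << max(11 - 4, 6)");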

/**
* Bit 0 is the least significant (rightmost);
* this implementation processes input from the most to the least significant bit.
@@ -447,7 +454,7 @@ __global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf,
 * Calculate the minimal batch size (and the number of blocks per row), such that the GPU is still fully occupied.
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
inline uint16_t get_optimal_batch_size(size_t req_batch_size, size_t blocks_per_row)
inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len)
{
int dev_id, sm_count, occupancy, max_grid_dim_y;
RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
@@ -456,6 +463,9 @@ inline uint16_t get_optimal_batch_size(size_t req_batch_size, size_t blocks_per_
RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, radix_kernel<T, IdxT, BitsPerPass, BlockSize>, BlockSize, 0));

// number of blocks we'd use if the batch size is enough to occupy the GPU in any case
size_t blocks_per_row = ceildiv<size_t>(len, BlockSize * ITEM_PER_THREAD);

// fully occupy GPU
size_t opt_batch_size = ceildiv<size_t>(sm_count * occupancy, blocks_per_row);
// round it up to the closest pow-of-two for better data alignment
@@ -475,9 +485,16 @@ inline uint16_t get_optimal_batch_size(size_t req_batch_size, size_t blocks_per_

// Do not exceed the max grid size.
opt_batch_size = std::min<size_t>(opt_batch_size, size_t(max_grid_dim_y));

// Don't do more work than needed
return uint16_t(std::min<size_t>(opt_batch_size, req_batch_size));
opt_batch_size = std::min<size_t>(opt_batch_size, req_batch_size);
// Let more blocks share one row if the required batch size is too small.
while (opt_batch_size * blocks_per_row < size_t(sm_count * occupancy) &&
// Ensure we still can read data somewhat efficiently
len * sizeof(T) > 2 * VECTORIZED_READ_SIZE * BlockSize * blocks_per_row) {
blocks_per_row <<= 1;
}

return dim3(blocks_per_row, opt_batch_size);
}
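
For intuition, here is a small standalone host-side sketch of the sizing logic above; the device numbers (80 SMs, occupancy 2) and the block/item/read-width constants below are made up for illustration, while the real code queries the device and occupancy at runtime and uses the ITEM_PER_THREAD and VECTORIZED_READ_SIZE constants:

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t ceildiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main()
{
  // Hypothetical device/kernel parameters (stand-ins for the real queried values).
  const size_t sm_count = 80, occupancy = 2, max_grid_dim_y = 65535;
  const size_t block_size = 512, items_per_thread = 32, vec_read_bytes = 16;
  const size_t len = 1000000, req_batch_size = 100, elem_size = 4;  // e.g. float keys

  size_t blocks_per_row = ceildiv(len, block_size * items_per_thread);    // 62
  size_t batch          = ceildiv(sm_count * occupancy, blocks_per_row);  // 3
  size_t pow2           = 1;
  while (pow2 < batch) pow2 <<= 1;  // round up to a power of two -> 4
  batch = std::min({pow2, max_grid_dim_y, req_batch_size});
  // Let more blocks share a row only if the batch alone cannot fill the GPU
  // and each thread would still read a reasonable amount of data per row.
  while (batch * blocks_per_row < sm_count * occupancy &&
         len * elem_size > 2 * vec_read_bytes * block_size * blocks_per_row) {
    blocks_per_row <<= 1;
  }
  std::printf("grid = (%zu, %zu)\n", blocks_per_row, batch);  // (62, 4) here
}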

/**
Expand Down Expand Up @@ -533,13 +550,20 @@ void radix_topk(const T* in,
bool select_min,
rmm::cuda_stream_view stream)
{
// reduce the block size if the input length is too small.
if constexpr (BlockSize > calc_min_block_size<BitsPerPass>()) {
if (BlockSize * ITEM_PER_THREAD > len) {
return radix_topk<T, IdxT, BitsPerPass, BlockSize / 2>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
}
}
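// Note: the recursion above terminates, since BlockSize is halved on each call and the
// `if constexpr` guard stops recursing once BlockSize reaches calc_min_block_size<BitsPerPass>(),
// which keeps the block at least two warps wide and large enough for the chosen radix.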

// TODO: is it possible to relax this restriction?
static_assert(calc_num_passes<T, BitsPerPass>() > 1);
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();

size_t blocks_per_row = ceildiv<size_t>(len, BlockSize * ITEM_PER_THREAD);
uint16_t max_chunk_size =
get_optimal_batch_size<T, IdxT, BitsPerPass, BlockSize>(batch_size, blocks_per_row);
dim3 blocks = get_optimal_grid_size<T, IdxT, BitsPerPass, BlockSize>(batch_size, len);
size_t max_chunk_size = blocks.y;

rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream);
rmm::device_uvector<IdxT> histograms(num_buckets * max_chunk_size, stream);
@@ -549,7 +573,7 @@ void radix_topk(const T* in,
rmm::device_uvector<IdxT> idx_buf2(len * max_chunk_size, stream);

for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) {
auto chunk_size = uint16_t(std::min<size_t>(max_chunk_size, batch_size - offset));
blocks.y = std::min(max_chunk_size, batch_size - offset);
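// Only the grid's y dimension shrinks for the last (possibly partial) chunk;
// blocks.x keeps the blocks-per-row value chosen by get_optimal_grid_size.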

RAFT_CUDA_TRY(
cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
@@ -560,8 +584,6 @@ void radix_topk(const T* in,
T* out_buf = nullptr;
IdxT* out_idx_buf = nullptr;

dim3 blocks(blocks_per_row, chunk_size);

constexpr int num_passes = calc_num_passes<T, BitsPerPass>();

for (int pass = 0; pass < num_passes; ++pass) {
