
Commit

Fix incorrect kernel launcher logic w.r.t. indexing type and attempt to improve its logic
achirkin committed Apr 4, 2022
1 parent 55c1d00 commit dba9c57
Showing 1 changed file with 66 additions and 39 deletions.
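
The "indexing type" part of the fix replaces expressions such as (len - 1) / n + 1 on caller-chosen index types with a ceildiv helper evaluated in size_t and only then narrowed to int. A minimal standalone sketch of that pattern (not part of the commit; the ceildiv below is a local stand-in for the helper the diff calls):

  #include <cstddef>

  // Stand-in for the ceildiv helper used in the diff: ceiling of a / b.
  template <typename T>
  constexpr T ceildiv(T a, T b)
  {
    return (a + b - 1) / b;
  }

  int main()
  {
    size_t len       = 1000000;  // problem size, kept in a wide unsigned type
    int num_of_block = 3;
    // Divide in size_t first, then narrow the (small) per-block result to int.
    int len_per_block = int(ceildiv<size_t>(len, num_of_block));
    return len_per_block == 333334 ? 0 : 1;
  }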
105 changes: 66 additions & 39 deletions cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
@@ -424,7 +424,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
};

template <typename T, typename IdxT>
-int calc_smem_size_for_block_wide(int num_of_warp, IdxT k)
+auto calc_smem_size_for_block_wide(int num_of_warp, int k) -> int
{
return Pow2<256>::roundUp(num_of_warp / 2 * sizeof(T) * k) + num_of_warp / 2 * sizeof(IdxT) * k;
}
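
For context, the size returned above is the per-block key buffer rounded up to a multiple of 256 bytes, plus the per-block index buffer. A standalone sketch (not part of the commit) that spells out the same arithmetic without the Pow2 helper, with purely illustrative sizes:

  #include <cstddef>
  #include <cstdio>

  // Same formula as calc_smem_size_for_block_wide, with Pow2<256>::roundUp expanded.
  template <typename T, typename IdxT>
  size_t smem_size_for_block_wide(int num_of_warp, int k)
  {
    size_t key_bytes = size_t(num_of_warp) / 2 * sizeof(T) * k;
    key_bytes        = (key_bytes + 255) / 256 * 256;  // round up to a 256-byte boundary
    size_t idx_bytes = size_t(num_of_warp) / 2 * sizeof(IdxT) * k;
    return key_bytes + idx_bytes;
  }

  int main()
  {
    // e.g. 8 warps and k = 100 with float keys and 32-bit indices:
    // keys = 4 * 4 * 100 = 1600 B -> 1792 B after rounding; indices = 1600 B; total 3392 B.
    std::printf("%zu\n", smem_size_for_block_wide<float, unsigned>(8, 100));
  }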
@@ -536,23 +536,32 @@ struct launch_setup {
* Returned block size
* @param[out] min_grid_size
* Returned minimum grid size needed to achieve the best potential occupancy
+ * @param[in] block_size_limit
+ *   Forcefully limit the block size (optional)
*/
-static void calc_optimal_params(int k, int* block_size, int* min_grid_size)
+static void calc_optimal_params(int k,
+                                int* block_size,
+                                int* min_grid_size,
+                                int block_size_limit = 0)
{
const int capacity = calc_capacity(k);
if constexpr (Capacity > 1) {
if (capacity < Capacity) {
return launch_setup<WarpSortClass, T, IdxT, Capacity / 2>::calc_optimal_params(
-  capacity, block_size, min_grid_size);
+  capacity, block_size, min_grid_size, block_size_limit);
}
}
ASSERT(capacity <= Capacity, "Requested k is too big (%d)", k);
auto calc_smem = [k](int block_size) {
-  int num_of_warp = block_size / WarpSize;
-  return calc_smem_size_for_block_wide<T>(num_of_warp, k);
+  int num_of_warp = block_size / std::min<int>(WarpSize, Capacity);
+  return calc_smem_size_for_block_wide<T, IdxT>(num_of_warp, k);
};
RAFT_CUDA_TRY(cudaOccupancyMaxPotentialBlockSizeVariableSMem(
-  min_grid_size, block_size, block_kernel<WarpSortClass, Capacity, true, T, IdxT>, calc_smem));
+  min_grid_size,
+  block_size,
+  block_kernel<WarpSortClass, Capacity, true, T, IdxT>,
+  calc_smem,
+  block_size_limit));
}
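
The new block_size_limit argument is forwarded as the last parameter of cudaOccupancyMaxPotentialBlockSizeVariableSMem, which caps the block size the occupancy search may return (0 means no cap). A minimal standalone usage sketch (not part of the commit; toy_kernel and its shared-memory functor are made up for illustration):

  #include <cuda_runtime.h>
  #include <cstdio>

  __global__ void toy_kernel(float* out)
  {
    extern __shared__ float buf[];
    buf[threadIdx.x]                           = float(threadIdx.x);
    out[blockIdx.x * blockDim.x + threadIdx.x] = buf[threadIdx.x];
  }

  int main()
  {
    int min_grid_size = 0;
    int block_size    = 0;
    // Dynamic shared memory needed by a block, as a function of its block size.
    auto smem_for_block = [](int bs) { return size_t(bs) * sizeof(float); };
    // The last argument limits the block size the search may return (0 = no limit).
    cudaOccupancyMaxPotentialBlockSizeVariableSMem(
      &min_grid_size, &block_size, toy_kernel, smem_for_block, /*blockSizeLimit=*/256);
    std::printf("block_size=%d min_grid_size=%d\n", block_size, min_grid_size);
    return 0;
  }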

static void kernel(int k,
@@ -633,57 +642,75 @@ struct LaunchThreshold<warp_sort_immediate> {
};

template <template <int, bool, typename, typename> class WarpSortClass, typename T, typename IdxT>
-void calc_launch_parameter(int batch_size, IdxT len, int k, int* p_num_of_block, int* p_num_of_warp)
+void calc_launch_parameter(
+  size_t batch_size, size_t len, int k, int* p_num_of_block, int* p_num_of_warp)
{
const int capacity = calc_capacity(k);
const int capacity_per_full_warp = std::max(capacity, WarpSize);
int block_size = 0;
int min_grid_size = 0;
launch_setup<WarpSortClass, T, IdxT>::calc_optimal_params(k, &block_size, &min_grid_size);
block_size = Pow2<WarpSize>::roundDown(block_size);

int num_of_warp;
int num_of_block;
-  if (batch_size < min_grid_size) {  // may use multiple blocks
+  if (batch_size < size_t(min_grid_size)) {  // may use multiple blocks
num_of_warp = block_size / WarpSize;
-    num_of_block = min_grid_size / batch_size;
-    int len_per_block = (len - 1) / num_of_block + 1;
-    int len_per_warp = (len_per_block - 1) / num_of_warp + 1;
+    num_of_block = min_grid_size / int(batch_size);
+    int len_per_block = int(ceildiv<size_t>(len, num_of_block));
+    int len_per_warp = ceildiv(len_per_block, num_of_warp);

len_per_warp = Pow2<WarpSize>::roundUp(len_per_warp);
len_per_block = len_per_warp * num_of_warp;
-    num_of_block = (len - 1) / len_per_block + 1;
+    num_of_block = int(ceildiv<size_t>(len, len_per_block));

-    constexpr int len_factor = LaunchThreshold<WarpSortClass>::len_factor_for_multi_block;
-    if (len_per_warp < capacity_per_full_warp * len_factor) {
-      len_per_warp = capacity_per_full_warp * len_factor;
+    constexpr int kLenFactor = LaunchThreshold<WarpSortClass>::len_factor_for_multi_block;
+    if (len_per_warp < capacity_per_full_warp * kLenFactor) {
+      len_per_warp = capacity_per_full_warp * kLenFactor;
len_per_block = num_of_warp * len_per_warp;
-      if ((IdxT)len_per_block > len) { len_per_block = len; }
-      num_of_block = (len - 1) / len_per_block + 1;
-      num_of_warp = (len_per_block - 1) / len_per_warp + 1;
+      if (size_t(len_per_block) > len) { len_per_block = len; }
+      num_of_block = int(ceildiv<size_t>(len, len_per_block));
+      num_of_warp = ceildiv(len_per_block, len_per_warp);
}
} else { // use only single block
num_of_block = 1;

-    // block size could be decreased if batch size is large
-    float scale = batch_size / min_grid_size;
-    if (scale > 1) {
-      // make sure scale > 1 so block_size only decreases not increases
-      if (0.8 * scale > 1) { scale = 0.8 * scale; }
-      block_size /= scale;
-      if (block_size < 1) { block_size = 1; }
-      block_size = Pow2<WarpSize>::roundUp(block_size);
-    }
+    auto adjust_block_size = [len, capacity_per_full_warp](int bs) {
+      int warps_per_block = bs / WarpSize;
+      int len_per_warp = int(ceildiv<size_t>(len, warps_per_block));
+      len_per_warp = Pow2<WarpSize>::roundUp(len_per_warp);
+      warps_per_block = int(ceildiv<size_t>(len, len_per_warp));

-    num_of_warp = block_size / WarpSize;
-    int len_per_warp = (len - 1) / num_of_warp + 1;
-    len_per_warp = Pow2<WarpSize>::roundUp(len_per_warp);
-    num_of_warp = (len - 1) / len_per_warp + 1;
+      constexpr int kLenFactor = LaunchThreshold<WarpSortClass>::len_factor_for_single_block;
+      if (len_per_warp < capacity_per_full_warp * kLenFactor) {
+        len_per_warp = capacity_per_full_warp * kLenFactor;
+        warps_per_block = int(ceildiv<size_t>(len, len_per_warp));
+      }

-    constexpr int len_factor = LaunchThreshold<WarpSortClass>::len_factor_for_single_block;
-    if (len_per_warp < capacity_per_full_warp * len_factor) {
-      len_per_warp = capacity_per_full_warp * len_factor;
-      num_of_warp = (len - 1) / len_per_warp + 1;
-    }
+      return warps_per_block * WarpSize;
+    };

+    // gradually reduce the block size while the batch size allows and the len is not big enough
+    // to occupy a single block well.
+    block_size = adjust_block_size(block_size);
+    do {
+      num_of_warp = block_size / WarpSize;
+      int another_block_size = 0, another_min_grid_size = 0;
+      launch_setup<WarpSortClass, T, IdxT>::calc_optimal_params(
+        k, &another_block_size, &another_min_grid_size, block_size);
+      another_block_size = adjust_block_size(another_block_size);
+      if (batch_size >= size_t(another_min_grid_size)  // still have enough work
+          && another_block_size < block_size  // protect against an infinite loop
+          && another_min_grid_size * another_block_size >
+               min_grid_size * block_size  // improve occupancy
+      ) {
+        block_size = another_block_size;
+        min_grid_size = another_min_grid_size;
+      } else {
+        break;
+      }
+    } while (block_size > WarpSize);
+    num_of_warp = std::max(1, num_of_warp);
}

*p_num_of_block = num_of_block;
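
To make the multi-block branch in the hunk above concrete, here is the same ceildiv / round-to-WarpSize sequence run on made-up numbers (a standalone sketch, not part of the commit; the starting block and warp counts are illustrative and WarpSize is assumed to be 32):

  #include <cstddef>
  #include <cstdio>

  int main()
  {
    const int WarpSize = 32;
    size_t len         = 100000;  // items per batch entry
    int num_of_block   = 4;       // pretend min_grid_size / batch_size gave 4
    int num_of_warp    = 16;      // pretend block_size / WarpSize gave 16

    int len_per_block = int((len + num_of_block - 1) / num_of_block);         // ceildiv -> 25000
    int len_per_warp  = (len_per_block + num_of_warp - 1) / num_of_warp;      // ceildiv -> 1563
    len_per_warp      = (len_per_warp + WarpSize - 1) / WarpSize * WarpSize;  // round up -> 1568
    len_per_block     = len_per_warp * num_of_warp;                           // 25088
    num_of_block      = int((len + len_per_block - 1) / len_per_block);       // ceildiv -> 4

    std::printf("%d blocks x %d warps, %d items per warp\n",
                num_of_block, num_of_warp, len_per_warp);
  }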
@@ -712,7 +739,7 @@ void warp_sort_topk_(int num_of_block,
T* result_val = (num_of_block == 1) ? out : tmp_val.data();
IdxT* result_idx = (num_of_block == 1) ? out_idx : tmp_idx.data();
int block_dim = num_of_warp * warp_width;
-  int smem_size = calc_smem_size_for_block_wide<T>(num_of_warp, (IdxT)k);
+  int smem_size = calc_smem_size_for_block_wide<T, IdxT>(num_of_warp, k);

launch_setup<WarpSortClass, T, IdxT>::kernel((IdxT)k,
select_min,
@@ -794,8 +821,8 @@ void warp_sort_topk(const T* in,
int capacity = calc_capacity(k);
int num_of_block = 0;
int num_of_warp = 0;
-  calc_launch_parameter<warp_sort_immediate, T>(
-    batch_size, len, (IdxT)k, &num_of_block, &num_of_warp);
+  calc_launch_parameter<warp_sort_immediate, T, IdxT>(
+    batch_size, len, k, &num_of_block, &num_of_warp);
int len_per_warp = len / (num_of_block * num_of_warp);

if (len_per_warp <= capacity * LaunchThreshold<warp_sort_immediate>::len_factor_for_choosing) {
