improve streamk load balance (NVIDIA#743)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
hwu36 authored Dec 25, 2022
1 parent 78b30d3 commit 1e64f15
Showing 2 changed files with 33 additions and 33 deletions.
28 changes: 18 additions & 10 deletions include/cutlass/gemm/device/gemm_universal_base.h
@@ -107,6 +107,8 @@ class GemmUniversalBase {
/// Kernel SM occupancy (in thread blocks)
thread_local static int sm_occupancy_;

+ /// Kernel dynamic shared memory allocation requirement
+ thread_local static int smem_size_;

/// Initialize static thread-local members for the thread's current device,
/// if necessary.
@@ -138,15 +140,15 @@
}

// Update the kernel function's shared memory configuration for the current device
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
- if (smem_size >= (48 << 10))
- {
- // Requires more than 48KB: configure for extended, dynamic shared memory
+ smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));

+ // If requires more than 48KB: configure for extended, dynamic shared memory
+ if (smem_size_ >= (48 << 10))
+ {
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- smem_size);
+ smem_size_);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
@@ -166,7 +168,7 @@
&sm_occupancy_,
Kernel2<GemmKernel>,
GemmKernel::kThreadCount,
- int(sizeof(typename GemmKernel::SharedStorage)),
+ smem_size_,
cudaOccupancyDisableCachingOverride);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@@ -179,7 +181,9 @@
CUTLASS_TRACE_HOST(" "
"device_ordinal: (" << device_ordinal_ << "), "
"device_sms: (" << device_sms_ << "), "
"sm_occupancy: (" << sm_occupancy_ << ")");
"sm_occupancy: (" << sm_occupancy_ << ") "
"smem_size: (" << smem_size_ << ") "
"GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");

return Status::kSuccess;
}
@@ -335,17 +339,16 @@ class GemmUniversalBase {
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");

// Configure grid and block dimensions
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
dim3 block(GemmKernel::kThreadCount, 1, 1);
dim3 grid = params_.get_grid_dims();

// Launch kernel
CUTLASS_TRACE_HOST(" "
"grid: (" << grid << "), "
"block: (" << block << "), "
"SMEM: (" << smem_size << ")");
"SMEM: (" << smem_size_ << ")");

- Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+ Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);

// Query for errors
cudaError_t result = cudaGetLastError();
@@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;

+ /// Kernel dynamic shared memory allocation requirement
+ template <typename GemmKernel_>
+ thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;



/////////////////////////////////////////////////////////////////////////////////////////////////

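Note on the gemm_universal_base.h changes above: the kernel's dynamic shared-memory requirement is now cached once in the thread-local smem_size_, so the opt-in to extended shared memory, the occupancy query, and the kernel launch all use the same value. Below is a minimal standalone CUDA sketch of that pattern; it is not CUTLASS code, and my_kernel, SMEM_BYTES, and the 128-thread block size are invented placeholders.

// Minimal standalone sketch (assumptions: my_kernel, SMEM_BYTES and the
// 128-thread block size are placeholders, not part of CUTLASS).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int SMEM_BYTES = 64 * 1024;     // assumed per-CTA dynamic SMEM demand

__global__ void my_kernel() {
  extern __shared__ char smem[];          // dynamic shared memory
  smem[threadIdx.x] = char(threadIdx.x);
}

int main() {
  // Kernels that need more than 48 KB of dynamic SMEM must opt in explicitly.
  if (SMEM_BYTES >= (48 << 10)) {
    cudaFuncSetAttribute(my_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         SMEM_BYTES);
  }

  // Query how many CTAs of this kernel fit per SM at that SMEM footprint.
  int sm_occupancy = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
      &sm_occupancy, my_kernel, /*blockSize=*/128, SMEM_BYTES,
      cudaOccupancyDisableCachingOverride);
  printf("sm_occupancy: %d, smem_size: %d\n", sm_occupancy, SMEM_BYTES);

  // Launch with the same cached dynamic SMEM size used for the query.
  my_kernel<<<1, 128, SMEM_BYTES>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}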
38 changes: 15 additions & 23 deletions include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
FastDivmod sk_iters_per_big_block;
FastDivmod sk_iters_per_region;
FastDivmod sk_blocks_per_region;
- FastDivmod sm_occupancy;

} div_mod;


@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
", dp_blocks: " << dp_blocks <<
", sk_blocks_per_region: " << sk_blocks_per_region <<
", sk_regions: " << sk_regions <<
", sk_waves: " << sk_waves <<
", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@
", sm_occupancy: " << sm_occupancy <<
", avail_sms: " << avail_sms <<
", cohort_raster: " << cohort_raster <<
", num_blocks: " << get_num_blocks() <<
"\n\n";
#endif
}
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {

// We're at (or greater) than GPU occupancy

- if (full_waves % sm_occupancy == sm_occupancy - 1)
+ if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
{
- // Form the SK wave from the partial wave to get us to full GPU occupancy
+ // If occupancy is more than one CTA per SM, form the SK wave from the partial
+ // wave to get us to full GPU occupancy
int max_sk_occupancy = 1;

dp_tiles = full_wave_tiles;
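Note on the hunk above: when only one CTA fits per SM, the old test full_waves % sm_occupancy == sm_occupancy - 1 reduces to x % 1 == 0 and is always true, so the partial wave was converted into stream-K work regardless of occupancy; the added (sm_occupancy > 1) guard limits the conversion to multi-occupancy kernels. A small hedged illustration (values invented, predicates copied from the old and new conditions, not CUTLASS code):

// Illustration only: compare the old and new predicates for a few invented
// occupancy/wave counts.
#include <cstdio>

int main() {
  for (int sm_occupancy : {1, 2}) {
    for (int full_waves : {3, 4}) {
      bool old_rule = (full_waves % sm_occupancy == sm_occupancy - 1);
      bool new_rule = (sm_occupancy > 1) && old_rule;
      printf("sm_occupancy=%d full_waves=%d old=%d new=%d\n",
             sm_occupancy, full_waves, old_rule, new_rule);
    }
  }
  return 0;
}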
@@ -533,15 +534,13 @@ struct ThreadblockSwizzleStreamK {
dp_first_wave_tiles += waveset_excess;
dp_blocks -= (waveset_excess * avail_sms);
}

}

// Setup fast-div/mod for device-side usage
div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
div_mod.iters_per_tile = FastDivmod(iters_per_tile);
- div_mod.sm_occupancy = FastDivmod(sm_occupancy);
}


@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
/// Obtains number of threadblocks per GEMM
int get_num_blocks() const
{
- // int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
- // return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;


int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;

- if (work_blocks < avail_sms)
+ if (work_blocks <= avail_sms * 2)
{
return work_blocks;
}

- int gpu_occupancy = sm_occupancy * avail_sms;
- int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
- return gpu_wavesets * gpu_occupancy;

+ return fast_max(work_blocks, avail_sms * 4);
}


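Note on the rewritten get_num_blocks() above: the waveset rounding is replaced by a simpler rule — launch small grids exactly, and pad anything larger to at least four waves of the available SMs. A hedged standalone sketch follows (assumption: the swizzle's member variables are turned into parameters, and std::max stands in for CUTLASS's fast_max):

// Standalone sketch of the new grid-sizing rule; not CUTLASS code.
#include <algorithm>

int get_num_blocks(int sk_waves, int dp_blocks, int reduction_blocks, int avail_sms) {
  int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
  if (work_blocks <= avail_sms * 2) {
    return work_blocks;                 // small grid: launch exactly what is needed
  }
  // Larger grids are padded to at least four waves; the two-wave block
  // remapping in get_block_idx() only triggers for grids of that size.
  return std::max(work_blocks, avail_sms * 4);
}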
@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
CUTLASS_DEVICE
int get_block_idx() const
{
- int block_idx = RematerializeBlockIdxX();
+ // Remap the block indices for the first two waves of thread blocks if
+ // we have multi-occupancy and the grid constitutes four or more waves

- int gpu_occupancy = avail_sms * sm_occupancy;
+ int block_idx = RematerializeBlockIdxX();
int num_blocks = device_num_blocks();
- int dest_sm, dest_wave;

- div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);

+ int dest_sm = block_idx / 2;
+ int dest_wave = block_idx % 2;
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);

- // remapping the first gpu_occupancy blocks
- if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+ if ((sm_occupancy > 1) &&
+ (num_blocks >= avail_sms * 4) &&
+ (block_idx < avail_sms * 2))
{
block_idx = remapped_block_idx;
}
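Note on the new get_block_idx() above: the div_mod.sm_occupancy fast-divmod is dropped in favor of a hard-coded two-wave remap — for multi-occupancy kernels with a grid of at least four waves, consecutive block indices 2*i and 2*i+1 are spread across the first two waves. A hedged standalone sketch of the index arithmetic (host-side form; it assumes blocks within a wave are assigned to SMs in order, so block dest_sm + dest_wave * avail_sms runs on SM dest_sm during wave dest_wave; not CUTLASS code):

// Standalone sketch of the remapping rule.
int remap_block_idx(int block_idx, int num_blocks, int avail_sms, int sm_occupancy) {
  int dest_sm   = block_idx / 2;                    // pair consecutive blocks
  int dest_wave = block_idx % 2;                    // alternate between wave 0 and wave 1
  int remapped_block_idx = dest_sm + (dest_wave * avail_sms);

  if ((sm_occupancy > 1) &&                         // only when 2+ CTAs fit per SM
      (num_blocks >= avail_sms * 4) &&              // and the grid spans 4+ waves
      (block_idx < avail_sms * 2)) {                // remap only the first two waves
    return remapped_block_idx;
  }
  return block_idx;                                 // all other blocks keep their index
}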
