diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h
index 14b640fa55..8f4d1b091b 100644
--- a/include/cutlass/gemm/device/gemm_universal_base.h
+++ b/include/cutlass/gemm/device/gemm_universal_base.h
@@ -107,6 +107,8 @@ class GemmUniversalBase {
   /// Kernel SM occupancy (in thread blocks)
   thread_local static int sm_occupancy_;
 
+  /// Kernel dynamic shared memory allocation requirement
+  thread_local static int smem_size_;
+
   /// Initialize static thread-local members for the thread's current device,
   /// if necessary.
@@ -138,15 +140,15 @@ class GemmUniversalBase {
     }
 
     // Update the kernel function's shared memory configuration for the current device
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
-    if (smem_size >= (48 << 10))
-    {
-      // Requires more than 48KB: configure for extended, dynamic shared memory
+    smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
+
+    // If requires more than 48KB: configure for extended, dynamic shared memory
+    if (smem_size_ >= (48 << 10))
+    {
       cudart_result = cudaFuncSetAttribute(
         Kernel2<GemmKernel>,
         cudaFuncAttributeMaxDynamicSharedMemorySize,
-        smem_size);
+        smem_size_);
       if (cudart_result != cudaSuccess) {
         CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
         return Status::kErrorInternal;
       }
@@ -166,7 +168,7 @@ class GemmUniversalBase {
       &sm_occupancy_,
       Kernel2<GemmKernel>,
       GemmKernel::kThreadCount,
-      int(sizeof(typename GemmKernel::SharedStorage)),
+      smem_size_,
       cudaOccupancyDisableCachingOverride);
     if (cudart_result != cudaSuccess) {
       CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@@ -179,7 +181,9 @@ class GemmUniversalBase {
     CUTLASS_TRACE_HOST("  "
       "device_ordinal: (" << device_ordinal_ << "), "
       "device_sms: (" << device_sms_ << "), "
-      "sm_occupancy: (" << sm_occupancy_ << ")");
+      "sm_occupancy: (" << sm_occupancy_ << ") "
+      "smem_size: (" << smem_size_ << ") "
+      "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
 
     return Status::kSuccess;
   }
@@ -335,7 +339,6 @@ class GemmUniversalBase {
     CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
 
     // Configure grid and block dimensions
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
     dim3 block(GemmKernel::kThreadCount, 1, 1);
     dim3 grid = params_.get_grid_dims();
 
@@ -343,9 +346,9 @@ class GemmUniversalBase {
     CUTLASS_TRACE_HOST("  "
       "grid: (" << grid << "), "
       "block: (" << block << "), "
-      "SMEM: (" << smem_size << ")");
+      "SMEM: (" << smem_size_ << ")");
 
-    Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+    Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
 
     // Query for errors
     cudaError_t result = cudaGetLastError();
@@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
 template <typename GemmKernel_>
 thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
 
+/// Kernel dynamic shared memory allocation requirement
+template <typename GemmKernel_>
+thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
+
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
index e31d7c7f26..499157ea65 100644
--- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
+++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
     FastDivmod sk_iters_per_big_block;
     FastDivmod sk_iters_per_region;
     FastDivmod sk_blocks_per_region;
-    FastDivmod sm_occupancy;
-
   } div_mod;
 
@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
       ", dp_blocks: " << dp_blocks <<
       ", sk_blocks_per_region: " << sk_blocks_per_region <<
      ", sk_regions: " << sk_regions <<
+      ", sk_waves: " << sk_waves <<
       ", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
       ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
       ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
       ", sm_occupancy: " << sm_occupancy <<
       ", avail_sms: " << avail_sms <<
       ", cohort_raster: " << cohort_raster <<
+      ", num_blocks: " << get_num_blocks() <<
       "\n\n";
 #endif
   }
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {
 
       // We're at (or greater) than GPU occupancy
 
-      if (full_waves % sm_occupancy == sm_occupancy - 1)
+      if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
       {
-        // Form the SK wave from the partial wave to get us to full GPU occupancy
+        // If occupancy is more than one CTA per SM, form the SK wave from the partial
+        // wave to get us to full GPU occupancy
         int max_sk_occupancy = 1;
 
         dp_tiles = full_wave_tiles;
@@ -533,7 +534,6 @@ struct ThreadblockSwizzleStreamK {
         dp_first_wave_tiles += waveset_excess;
         dp_blocks -= (waveset_excess * avail_sms);
       }
-
     }
 
     // Setup fast-div/mod for device-side usage
@@ -541,7 +541,6 @@ struct ThreadblockSwizzleStreamK {
     div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
     div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
     div_mod.iters_per_tile = FastDivmod(iters_per_tile);
-    div_mod.sm_occupancy = FastDivmod(sm_occupancy);
   }
 
@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
   /// Obtains number of threadblocks per GEMM
   int get_num_blocks() const
   {
-//    int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
-//    return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
-
-
     int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
 
-    if (work_blocks < avail_sms)
+    if (work_blocks <= avail_sms * 2)
     {
       return work_blocks;
     }
 
-    int gpu_occupancy = sm_occupancy * avail_sms;
-    int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
-    return gpu_wavesets * gpu_occupancy;
-
+    return fast_max(work_blocks, avail_sms * 4);
   }
 
@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
   CUTLASS_DEVICE
   int get_block_idx() const
   {
-    int block_idx = RematerializeBlockIdxX();
+    // Remap the block indices for the first two waves of thread blocks if
+    // we have multi-occupancy and the grid constitutes four or more waves
 
-    int gpu_occupancy = avail_sms * sm_occupancy;
+    int block_idx = RematerializeBlockIdxX();
     int num_blocks = device_num_blocks();
-    int dest_sm, dest_wave;
-
-    div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
-
+    int dest_sm = block_idx / 2;
+    int dest_wave = block_idx % 2;
     int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
 
-    // remapping the first gpu_occupancy blocks
-    if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+    if ((sm_occupancy > 1) &&
+        (num_blocks >= avail_sms * 4) &&
+        (block_idx < avail_sms * 2))
     {
       block_idx = remapped_block_idx;
     }
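
Note on the gemm_universal_base.h hunks above: the cached smem_size_ now feeds three places that are meant to see the same byte count, the cudaFuncSetAttribute() opt-in for more than 48KB of dynamic shared memory, the cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() query, and the kernel launch in run(). The following is a minimal standalone sketch of that CUDA pattern; it is not part of the patch, and example_kernel plus the kThreads/kSmemBytes constants are illustrative placeholders.

// Sketch: opt a kernel into extended dynamic shared memory, query occupancy with the
// same byte count, then launch with that same byte count (mirrors init_device_props()/run()).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void example_kernel(int *out)
{
  extern __shared__ char smem[];          // dynamic shared memory
  if (threadIdx.x == 0) {
    smem[0] = 1;
    out[blockIdx.x] = int(smem[0]);
  }
}

int main()
{
  int const kThreads   = 256;             // illustrative block size
  int const kSmemBytes = 64 << 10;        // 64KB: exceeds the 48KB default limit

  // Kernels requesting 48KB or more of dynamic shared memory must opt in explicitly
  if (kSmemBytes >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        example_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);
    if (result != cudaSuccess) {
      printf("cudaFuncSetAttribute() error: %s\n", cudaGetErrorString(result));
      return -1;
    }
  }

  // Query SM occupancy with the same dynamic shared memory size used at launch time
  int sm_occupancy = -1;
  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
      &sm_occupancy, example_kernel, kThreads, kSmemBytes,
      cudaOccupancyDisableCachingOverride);
  printf("sm_occupancy: %d\n", sm_occupancy);

  int *out = nullptr;
  cudaMalloc(&out, sizeof(int));

  // Launch with the identical dynamic shared memory size
  example_kernel<<<1, kThreads, kSmemBytes>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}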
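
Similarly, the new get_block_idx() in threadblock_swizzle_streamk.h replaces the FastDivmod-based remapping with fixed /2 and %2 arithmetic over the first two waves. A small host-side illustration of that arithmetic follows; it is not part of the patch, and avail_sms, sm_occupancy, and num_blocks are made-up example values.

// Prints how the first avail_sms * 2 block indices are remapped when the patch's guard
// (sm_occupancy > 1 and num_blocks >= avail_sms * 4) is satisfied. With avail_sms = 4,
// blocks 0..7 map to 0,4,1,5,2,6,3,7: blocks 2*s and 2*s+1 both get dest_sm == s,
// one in wave 0 and one in wave 1.
#include <cstdio>

int main()
{
  int const avail_sms    = 4;    // example SM count
  int const sm_occupancy = 2;    // example kernel occupancy (CTAs per SM)
  int const num_blocks   = 20;   // >= avail_sms * 4, so the remap is active

  for (int block_idx = 0; block_idx < avail_sms * 2; ++block_idx) {
    int dest_sm   = block_idx / 2;
    int dest_wave = block_idx % 2;
    int remapped  = dest_sm + (dest_wave * avail_sms);

    bool remap = (sm_occupancy > 1) &&
                 (num_blocks >= avail_sms * 4) &&
                 (block_idx < avail_sms * 2);

    printf("block %d -> %d\n", block_idx, remap ? remapped : block_idx);
  }
  return 0;
}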