Description
Required prerequisites
- I have searched the Issue Tracker and confirmed that this hasn't already been reported. (If it has, please comment there.)
Motivation
T.reduce behaves poorly when the layout requires each thread to hold multiple reduced values. See the program below.
The fragment layout is satisfactory: 8 * sizeof(bfloat16) = 16 bytes is an excellent read width from shared memory. However, the T.reduce code does not exploit this. Instead, it fetches the elements from shared memory one by one and computes the 8 max values sequentially. Since tl::AllReduce uses a butterfly reduction algorithm, there are many thread syncs on the path: this simple piece of code performs 8 * log2(256 / 32) = 24 thread syncs! I suggest computing these max values in parallel, i.e., a single butterfly reduction path with all 8 values in flight.
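To make the sync arithmetic above concrete, here is a small cost-model sketch of my own (not tilelang code; the function name and the one-barrier-per-butterfly-level assumption are mine) counting the cross-warp barriers for the sequential scheme versus the proposed parallel one:

```python
import math

THREADS = 256  # block size in the reproducer below
WARP = 32      # values per warp are reduced with shuffles, no barrier needed


def butterfly_syncs(num_values: int, vectorized: bool) -> int:
    """Barriers to all-reduce `num_values` per-thread values across the block.

    Assumes one barrier per cross-warp butterfly level. The sequential
    scheme walks the whole path once per value; the vectorized scheme
    walks it once with all values in flight.
    """
    levels = int(math.log2(THREADS // WARP))  # cross-warp butterfly levels
    return num_values * levels if not vectorized else levels


print(butterfly_syncs(8, vectorized=False))  # current behaviour: 24
print(butterfly_syncs(8, vectorized=True))   # proposed: 3
```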
One may argue that parallel reduction costs more shared memory. However, issue #1761 suggests that each warp, instead of each thread, hold one copy of the values to reduce in shared memory. The memory saved that way is exactly enough for the parallel reduction proposed here!
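A back-of-envelope budget for that claim (my own arithmetic, assuming the per-warp scratch layout proposed in #1761; the variable names are illustrative):

```python
THREADS, WARP, BYTES = 256, 32, 2  # 256 threads, bfloat16 elements

# Today: every thread stages one value per butterfly pass.
per_thread_scratch = THREADS * BYTES

# With #1761's layout: one copy per warp instead of per thread.
per_warp_scratch = (THREADS // WARP) * BYTES

# Parallel reduction under #1761: all 8 values per warp in flight at once.
parallel_scratch = (THREADS // WARP) * 8 * BYTES

print(per_thread_scratch, per_warp_scratch, parallel_scratch)
```

Even with all 8 values staged together, the per-warp layout uses less scratch than the current per-thread single-value staging.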
```python
import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_sample_kernel():

    @T.prim_func
    def sample_kernel(a: T.Tensor([128, 64], T.bfloat16)):
        with T.Kernel(1, threads=256):
            a_shared = T.alloc_shared([128, 64], T.bfloat16)
            amax_fragment = T.alloc_fragment([64], T.bfloat16)
            T.copy(a, a_shared)
            T.reduce_max(a_shared, amax_fragment, dim=0)

    return sample_kernel


kernel = get_sample_kernel()
print(kernel.get_kernel_source())
```

The generated kernel source:

```cpp
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void sample_kernel_kernel(const bfloat16_t* __restrict__ a);
extern "C" __global__ void __launch_bounds__(256, 1) sample_kernel_kernel(const bfloat16_t* __restrict__ a) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  bfloat16_t a_shared_frag[32];
  bfloat16_t amax_fragment[8];
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + ((i * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(a + ((i * 2048) + (((int)threadIdx.x) * 8)));
  }
  #pragma unroll
  for (int i_1 = 0; i_1 < 4; ++i_1) {
    *(uint4*)(a_shared_frag + (i_1 * 8)) = *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + ((i_1 * 2048) + (((int)threadIdx.x) * 8)));
  }
  __syncthreads();
  #pragma unroll
  for (int i_2 = 0; i_2 < 8; ++i_2) {
    amax_fragment[i_2] = -std::numeric_limits<bfloat16_t>::infinity();
    #pragma unroll
    for (int rv = 0; rv < 4; ++rv) {
      amax_fragment[i_2] = cutlass::bfloat16_t(__hmax((amax_fragment[i_2]).to_nv_bfloat16(), (a_shared_frag[((rv * 8) + i_2)]).to_nv_bfloat16()));
    }
    amax_fragment[i_2] = tl::AllReduce<tl::MaxOp, 256, 8, 0, tl::NamedBarrier<256>>::run(amax_fragment[i_2], (&(((bfloat16_t*)buf_dyn_shmem)[0])));
  }
}
```

Solution
No response
Alternatives
No response
Additional context
No response