```python
import torch

import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        # Disable warp specialization and TMA lowering to keep the
        # generated CUDA as simple as possible.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            # Stage one row of x into shared memory, then scan it in place.
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            T.cumsum(smem)

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel(128)
    print(kernel.get_kernel_source())
    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)
```
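For reference, the kernel is meant to compute an inclusive cumulative sum over each 128-element row staged in shared memory. Below is a minimal host-side sketch of those semantics, assuming the no-argument `T.cumsum` is an inclusive forward scan; note the repro never copies `smem` back to `x`, so the result is not observable from Python and the comparison is purely illustrative:

```python
# Expected semantics of T.cumsum on one 128-element row (assumption:
# the no-argument form is an inclusive forward scan over the buffer).
import torch

row = torch.randn(128, dtype=torch.float, device='cuda')
expected = torch.cumsum(row, dim=0)  # expected[i] == row[:i + 1].sum()
```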
extern "C" __global__ void __launch_bounds__(128, 1) buggy_kernel_kernel(float* __restrict__ x, int num_tokens) {
extern __shared__ __align__(1024) float smem[];
#pragma unroll
for (int i = 0; i < 1; ++i) {
smem[((int)threadIdx.x)] = x[((((int64_t)((int)blockIdx.x)) * (int64_t)128) + ((int64_t)((int)threadIdx.x)))];
}
tl::fence_proxy_async();
__syncthreads();
tl::CumSum1D<128, false>::run((&(smem[0])), (&(smem[0])), 128);
}Metadata
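For context on the final call: `tl::CumSum1D<128, false>::run` presumably performs a block-wide inclusive scan over the 128 floats in `smem` (the template arguments are assumed here to be the block size and a reverse flag, matching `T.cumsum`'s defaults). A generic Hillis-Steele scan sketched in Python shows the computation such a primitive is expected to carry out; this is not tilelang's actual implementation:

```python
# Illustrative Hillis-Steele inclusive scan, simulating the lockstep
# steps a 128-thread block would take; NOT tilelang's CumSum1D.
import torch

def hillis_steele_scan(buf: torch.Tensor) -> torch.Tensor:
    out = buf.clone()
    offset = 1
    while offset < out.numel():
        shifted = torch.zeros_like(out)
        shifted[offset:] = out[:-offset]  # "thread" tid reads slot tid - offset
        out = out + shifted               # all threads add in lockstep
        offset *= 2
    return out

x = torch.arange(4, dtype=torch.float)
print(hillis_steele_scan(x))   # tensor([0., 1., 3., 6.])
print(torch.cumsum(x, dim=0))  # same result
```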