-
Notifications
You must be signed in to change notification settings - Fork 333
Closed
Description
import torch
import tilelang
from tilelang import language as T
@tilelang.jit(
    pass_configs={
        # Disable warp specialization and TMA lowering to narrow the repro
        # down to the plain copy + reduce path.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    """Build the JIT-compiled kernel that reproduces the reported bug.

    Args:
        hidden: row width of the input tensor; also the size of the
            shared-memory staging buffer.

    Returns:
        The compiled prim_func. Each program instance copies one row of
        ``x`` into shared memory, reduces it with ``T.sum``, and prints
        the result.
    """
    # Symbolic leading dimension so the kernel accepts any number of rows.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        # One program instance (a block of 128 threads) per token/row.
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden, ), dtype='float')
            T.copy(x[pid, :], smem)
            local_var = T.sum(smem)  # As titled
            T.print(local_var)

    return buggy_kernel
if __name__ == '__main__':
    # Build the kernel for a 128-wide hidden dimension and dump the
    # generated source for inspection.
    kernel = get_buggy_kernel(128)
    print(kernel.get_kernel_source())
    # Single zero-filled row on the GPU; the kernel prints the reduced sum.
    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)
As demonstrated.
Metadata
Metadata
Assignees
Labels
No labels