import pytest
import torch

import tilelang
import tilelang.testing
import tilelang.language as T

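# Kernel factory: each call JIT-compiles a single-warp kernel specialized for
# the requested warp-level reduction op and element dtype.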
@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):

    assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]

    @T.prim_func
    def main(x: T.Tensor((32,), dtype)):
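        # One block of 32 threads, i.e. a single warp; thread tx owns element x[tx].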
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding(0)
            local_val = T.alloc_local([1], dtype)
            local_val[0] = x[tx]
            reduced_val = T.alloc_local([1], dtype)
            if reduce_op == "sum":
                reduced_val[0] = T.warp_reduce_sum(local_val[0])
            elif reduce_op == "max":
                reduced_val[0] = T.warp_reduce_max(local_val[0])
            elif reduce_op == "min":
                reduced_val[0] = T.warp_reduce_min(local_val[0])
            elif reduce_op == "bitand":
                reduced_val[0] = T.warp_reduce_bitand(local_val[0])
            elif reduce_op == "bitor":
                reduced_val[0] = T.warp_reduce_bitor(local_val[0])
            x[tx] = reduced_val[0]

    return main

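# Each test runs the kernel in place on a 32-element tensor and checks that
# every lane ends up holding the warp-wide reduction result.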
def test_warp_reduce_sum():
    a = torch.randn((32,), dtype=torch.float32, device="cuda")
    kernel = get_kernel("sum", "float32")
    ref = torch.full_like(a, a.sum())
    kernel(a)
    torch.testing.assert_close(a, ref)

def test_warp_reduce_max():
    a = torch.randn((32,), dtype=torch.float32, device="cuda")
    kernel = get_kernel("max", "float32")
    print(kernel.get_kernel_source())
    ref = torch.full_like(a, a.max())
    kernel(a)
    torch.testing.assert_close(a, ref)

def test_warp_reduce_min():
    a = torch.randn((32,), dtype=torch.float32, device="cuda")
    kernel = get_kernel("min", "float32")
    ref = torch.full_like(a, a.min())
    kernel(a)
    torch.testing.assert_close(a, ref)

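# The bitwise reductions use int32 inputs; the reference value is a sequential
# fold of the operator over all 32 elements.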
def test_warp_reduce_bitand():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
    kernel = get_kernel("bitand", "int32")
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val & a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)

def test_warp_reduce_bitor():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
    kernel = get_kernel("bitor", "int32")
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val | a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)

if __name__ == "__main__":
    tilelang.testing.main()