|  | 
|  | 1 | +from typing import Optional | 
|  | 2 | + | 
|  | 3 | +import tilelang.language as T | 
|  | 4 | +import tilelang.testing | 
|  | 5 | +import torch | 
|  | 6 | +from tilelang.utils.target import check_hip_availability | 
|  | 7 | + | 
# Probed once at import time; selects the HIP (64-lane) vs CUDA (32-lane)
# default warp size in _resolve_warp_size below.
_IS_HIP_AVAILABLE = check_hip_availability()
# Default number of warps per warp group when the caller passes None.
_DEFAULT_WARPS_PER_GROUP = 4
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +def _resolve_warp_size(warp_size: Optional[int]) -> int: | 
|  | 13 | +    if warp_size is not None: | 
|  | 14 | +        return int(warp_size) | 
|  | 15 | +    return 64 if _IS_HIP_AVAILABLE else 32 | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: | 
|  | 19 | +    if warps_per_group is not None: | 
|  | 20 | +        return int(warps_per_group) | 
|  | 21 | +    return _DEFAULT_WARPS_PER_GROUP | 
|  | 22 | + | 
|  | 23 | + | 
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
    """Build a kernel in which every thread records its lane index.

    Args:
        num_threads: Thread count of the single launched block.
        warp_size: Value forwarded to ``T.get_lane_idx``; ``None`` presumably
            selects the target's native warp width (the run_* checks expect
            64 on HIP, 32 otherwise — see ``_resolve_warp_size``).

    ``out_idx=[-1]`` marks the last parameter (``A``) as the output, so the
    compiled kernel is invoked with no arguments and returns ``A``.
    """

    @T.prim_func
    def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
        # One block; each thread writes its own lane index at its slot.
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_lane_idx(warp_size)

    return laneid_kernel
|  | 34 | + | 
|  | 35 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
    """Build a kernel in which every thread records its warp index
    via the synchronized variant ``T.get_warp_idx_sync``.

    Args:
        num_threads: Thread count of the single launched block.
        warp_size: Value forwarded to the intrinsic; ``None`` presumably
            selects the target default (see ``_resolve_warp_size``).

    ``out_idx=[-1]`` marks ``A`` as the output tensor returned by the call.
    """

    @T.prim_func
    def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx_sync(warp_size)

    return warp_idx_sync_kernel
|  | 46 | + | 
|  | 47 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
    """Build a kernel in which every thread records its warp index
    via ``T.get_warp_idx`` (unsynchronized variant).

    Args:
        num_threads: Thread count of the single launched block.
        warp_size: Value forwarded to the intrinsic; ``None`` presumably
            selects the target default (see ``_resolve_warp_size``).

    ``out_idx=[-1]`` marks ``A`` as the output tensor returned by the call.
    """

    @T.prim_func
    def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx(warp_size)

    return warp_idx_kernel
|  | 58 | + | 
|  | 59 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):
    """Build a kernel in which every thread records its warp-group index.

    Args:
        num_threads: Thread count of the single launched block.
        warp_size: Forwarded to ``T.get_warp_group_idx``; ``None`` presumably
            means the target default (see ``_resolve_warp_size``).
        warps_per_group: Forwarded to the intrinsic; ``None`` presumably
            means the default of 4 (see ``_resolve_warps_per_group``).

    ``out_idx=[-1]`` marks ``A`` as the output tensor returned by the call.
    """

    @T.prim_func
    def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)

    return warp_group_idx_kernel
|  | 74 | + | 
|  | 75 | + | 
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
    """Build a kernel in which every thread records whether it was elected.

    Per the reference computation in ``run_shuffle_elect``: with
    ``thread_extent == 0`` only thread 0 of the block is elected; with a
    positive ``thread_extent`` the first thread of each ``thread_extent``-sized
    group is elected.

    ``out_idx=[-1]`` marks ``A`` as the output tensor returned by the call.
    """

    @T.prim_func
    def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # elected is 1 for the leader thread(s), 0 elsewhere.
            elected = T.shuffle_elect(thread_extent)
            A[tx] = elected

    return shuffle_elect_kernel
|  | 87 | + | 
|  | 88 | + | 
def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
    """Launch the lane-index kernel and check ``tid % warp_size`` per thread."""
    ws = _resolve_warp_size(warp_size)
    jit_kernel = _get_laneid_kernel(num_threads, warp_size)
    out = jit_kernel()
    print(jit_kernel.get_kernel_source())
    print(out)
    expected = torch.arange(num_threads, dtype=out.dtype, device=out.device) % ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 98 | + | 
|  | 99 | + | 
def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
    """Launch the synced warp-index kernel and check ``tid // warp_size``."""
    ws = _resolve_warp_size(warp_size)
    jit_kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
    out = jit_kernel()
    print(jit_kernel.get_kernel_source())
    print(out)
    expected = torch.arange(num_threads, dtype=out.dtype, device=out.device) // ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 109 | + | 
|  | 110 | + | 
def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
    """Launch the warp-index kernel and check ``tid // warp_size``."""
    ws = _resolve_warp_size(warp_size)
    jit_kernel = _get_warp_idx_kernel(num_threads, warp_size)
    out = jit_kernel()
    print(jit_kernel.get_kernel_source())
    print(out)
    expected = torch.arange(num_threads, dtype=out.dtype, device=out.device) // ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 120 | + | 
|  | 121 | + | 
def run_get_warp_group_idx(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):
    """Launch the warp-group-index kernel and check
    ``tid // (warp_size * warps_per_group)`` per thread."""
    jit_kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
    out = jit_kernel()
    print(jit_kernel.get_kernel_source())
    print(out)
    group_threads = _resolve_warp_size(warp_size) * _resolve_warps_per_group(warps_per_group)
    if group_threads <= 0:
        raise ValueError("threads_per_group must be positive.")
    expected = torch.arange(num_threads, dtype=out.dtype, device=out.device) // group_threads
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 139 | + | 
|  | 140 | + | 
def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
    """Launch the shuffle-elect kernel and validate the elected-thread mask.

    With ``thread_extent == 0`` exactly one leader (thread 0) is expected;
    with a positive ``thread_extent`` the first thread of every
    ``thread_extent``-sized group is expected to be elected.

    Args:
        num_threads: Thread count of the single launched block.
        thread_extent: Election group size; 0 means block-wide leader.

    Raises:
        ValueError: If ``thread_extent`` is negative.
    """
    if thread_extent < 0:
        raise ValueError("thread_extent must be non-negative.")
    kernel = _shuffle_elect_kernel(num_threads, thread_extent)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
    # Negative extents were rejected above, so only two cases remain.
    # (The previous version carried an unreachable `else` branch for
    # thread_extent < 0 here.)
    if thread_extent == 0:
        mask = indices == 0
    else:
        mask = (indices % thread_extent) == 0
    ref = mask.to(dtype=A.dtype, device=A.device)
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A
|  | 158 | + | 
|  | 159 | + | 
@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
    """Lane index with default thread count and target-default warp size."""
    run_get_lane_id()
|  | 163 | + | 
|  | 164 | + | 
@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
    """Lane index with 256 threads and an explicit 64-lane warp size."""
    run_get_lane_id(num_threads=256, warp_size=64)
|  | 168 | + | 
|  | 169 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
    """Synced warp index with all defaults."""
    run_get_warp_idx_sync()
|  | 173 | + | 
|  | 174 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
    """Synced warp index with 256 threads and a narrow 16-lane warp size."""
    run_get_warp_idx_sync(num_threads=256, warp_size=16)
|  | 178 | + | 
|  | 179 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
    """Warp index with all defaults."""
    run_get_warp_idx()
|  | 183 | + | 
|  | 184 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
    """Warp index with 320 threads and a non-power-of-two warp size (20)."""
    run_get_warp_idx(num_threads=320, warp_size=20)
|  | 188 | + | 
|  | 189 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
    """Warp-group index with all defaults (4 warps per group)."""
    run_get_warp_group_idx()
|  | 193 | + | 
|  | 194 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
    """Warp-group index with 512 threads, 32-lane warps, 5 warps per group."""
    run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)
|  | 198 | + | 
|  | 199 | + | 
@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
    """Shuffle-elect: one leader per 64-thread group across 256 threads."""
    run_shuffle_elect(num_threads=256, thread_extent=64)
|  | 203 | + | 
|  | 204 | + | 
@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
    """Shuffle-elect with extent 0: a single block-wide leader is elected."""
    run_shuffle_elect(num_threads=128, thread_extent=0)
|  | 208 | + | 
|  | 209 | + | 
if __name__ == "__main__":
    # Removed commented-out debug call; run a single helper manually if needed.
    tilelang.testing.main()
0 commit comments