|  | 
|  | 1 | +from typing import Optional | 
|  | 2 | + | 
|  | 3 | +import tilelang.language as T | 
|  | 4 | +import tilelang.testing | 
|  | 5 | +import torch | 
|  | 6 | +from tilelang.utils.target import check_hip_availability | 
|  | 7 | + | 
|  | 8 | +_IS_HIP_AVAILABLE = check_hip_availability() | 
|  | 9 | +_DEFAULT_WARPS_PER_GROUP = 4 | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +def _resolve_warp_size(warp_size: Optional[int]) -> int: | 
|  | 13 | +    if warp_size is not None: | 
|  | 14 | +        return int(warp_size) | 
|  | 15 | +    return 64 if _IS_HIP_AVAILABLE else 32 | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: | 
|  | 19 | +    if warps_per_group is not None: | 
|  | 20 | +        return int(warps_per_group) | 
|  | 21 | +    return _DEFAULT_WARPS_PER_GROUP | 
|  | 22 | + | 
|  | 23 | + | 
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
    """Build a single-block kernel that writes each thread's lane index to A[tx].

    Args:
        num_threads: Thread count of the one kernel block; also the length
            of the int32 output tensor A.
        warp_size: Forwarded verbatim to T.get_lane_idx. None defers to the
            intrinsic's own default (presumably the target warp size —
            mirrors the host-side _resolve_warp_size; confirm against
            tilelang docs).
    """

    @T.prim_func
    def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # Each thread records its own lane index at its thread slot.
            A[tx] = T.get_lane_idx(warp_size)

    return laneid_kernel
|  | 34 | + | 
|  | 35 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(
    num_threads: int = 128, warp_size: Optional[int] = None
):
    """Build a single-block kernel writing T.get_warp_idx_sync per thread to A[tx].

    Args:
        num_threads: Thread count of the one kernel block / output length.
        warp_size: Forwarded verbatim to T.get_warp_idx_sync; None defers
            to the intrinsic's default (presumably the target warp size —
            confirm against tilelang docs).
    """

    @T.prim_func
    def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx_sync(warp_size)

    return warp_idx_sync_kernel
|  | 48 | + | 
|  | 49 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
    """Build a single-block kernel writing T.get_warp_idx per thread to A[tx].

    Args:
        num_threads: Thread count of the one kernel block / output length.
        warp_size: Forwarded verbatim to T.get_warp_idx; None defers to the
            intrinsic's default (presumably the target warp size — confirm
            against tilelang docs).
    """

    @T.prim_func
    def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx(warp_size)

    return warp_idx_kernel
|  | 60 | + | 
|  | 61 | + | 
@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):
    """Build a single-block kernel writing T.get_warp_group_idx per thread to A[tx].

    Args:
        num_threads: Thread count of the one kernel block / output length.
        warp_size: Forwarded verbatim to T.get_warp_group_idx; None defers
            to the intrinsic's default.
        warps_per_group: Forwarded verbatim; None defers to the intrinsic's
            default (host-side checks assume 4 — see _DEFAULT_WARPS_PER_GROUP).
    """

    @T.prim_func
    def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)

    return warp_group_idx_kernel
|  | 76 | + | 
|  | 77 | + | 
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(
    num_threads: int = 128, thread_extent: int = 64
):
    """Build a single-block kernel storing T.shuffle_elect's result per thread.

    Args:
        num_threads: Thread count of the one kernel block / output length.
        thread_extent: Forwarded verbatim to T.shuffle_elect. The host-side
            reference in run_shuffle_elect suggests one thread per
            extent-sized group is elected, and extent 0 elects only the
            block leader — confirm against the intrinsic's docs.
    """

    @T.prim_func
    def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            elected = T.shuffle_elect(thread_extent)
            A[tx] = elected

    return shuffle_elect_kernel
|  | 91 | + | 
|  | 92 | + | 
def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
    """Run the lane-index kernel and assert A[tx] == tx % warp_size."""
    compiled = _get_laneid_kernel(num_threads, warp_size)
    out = compiled()
    print(compiled.get_kernel_source())
    print(out)
    ws = _resolve_warp_size(warp_size)
    thread_ids = torch.arange(num_threads, dtype=out.dtype, device=out.device)
    expected = thread_ids % ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 104 | + | 
|  | 105 | + | 
def run_get_warp_idx_sync(
    num_threads: int = 128, warp_size: Optional[int] = None
):
    """Run the warp-idx-sync kernel and assert A[tx] == tx // warp_size."""
    compiled = _get_warp_idx_sync_kernel(num_threads, warp_size)
    out = compiled()
    print(compiled.get_kernel_source())
    print(out)
    ws = _resolve_warp_size(warp_size)
    thread_ids = torch.arange(num_threads, dtype=out.dtype, device=out.device)
    expected = thread_ids // ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 119 | + | 
|  | 120 | + | 
def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
    """Run the warp-idx kernel and assert A[tx] == tx // warp_size."""
    compiled = _get_warp_idx_kernel(num_threads, warp_size)
    out = compiled()
    print(compiled.get_kernel_source())
    print(out)
    ws = _resolve_warp_size(warp_size)
    thread_ids = torch.arange(num_threads, dtype=out.dtype, device=out.device)
    expected = thread_ids // ws
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 132 | + | 
|  | 133 | + | 
def run_get_warp_group_idx(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):
    """Run the warp-group-idx kernel; assert A[tx] == tx // (warp_size * warps_per_group)."""
    compiled = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
    out = compiled()
    print(compiled.get_kernel_source())
    print(out)
    group_threads = _resolve_warp_size(warp_size) * _resolve_warps_per_group(
        warps_per_group
    )
    if group_threads <= 0:
        raise ValueError("threads_per_group must be positive.")
    thread_ids = torch.arange(num_threads, dtype=out.dtype, device=out.device)
    expected = thread_ids // group_threads
    torch.testing.assert_close(out.cpu(), expected.cpu())
    return out
|  | 153 | + | 
|  | 154 | + | 
def run_shuffle_elect(
    num_threads: int = 128, thread_extent: int = 64
):
    """Run the shuffle-elect kernel and check the elected-thread mask.

    Args:
        num_threads: Number of threads launched (and output length).
        thread_extent: Election group size. 0 means "elect only the block
            leader" (thread 0); a positive value elects the first thread
            of each extent-sized group.

    Returns:
        The int32 result tensor produced by the kernel.

    Raises:
        ValueError: If ``thread_extent`` is negative.
    """
    if thread_extent < 0:
        raise ValueError("thread_extent must be non-negative.")
    kernel = _shuffle_elect_kernel(num_threads, thread_extent)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    indices = torch.arange(
        num_threads, device=A.device, dtype=torch.int64
    )
    # NOTE: the negative-extent branch was unreachable (guarded above) and
    # has been removed.
    if thread_extent == 0:
        # Extent 0: only the single block leader is elected.
        mask = indices == 0
    else:
        # First thread of every extent-sized group is elected.
        mask = (indices % thread_extent) == 0
    ref = mask.to(dtype=A.dtype, device=A.device)
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A
|  | 176 | + | 
|  | 177 | + | 
@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
    # Default config: 128 threads, target-default warp size.
    run_get_lane_id()
|  | 181 | + | 
|  | 182 | + | 
@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
    # Explicit 64-wide warps across 256 threads.
    run_get_lane_id(num_threads=256, warp_size=64)
|  | 186 | + | 
|  | 187 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
    # Default config: 128 threads, target-default warp size.
    run_get_warp_idx_sync()
|  | 191 | + | 
|  | 192 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
    # Non-hardware warp size (16) to exercise the explicit-size path.
    run_get_warp_idx_sync(num_threads=256, warp_size=16)
|  | 196 | + | 
|  | 197 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
    # Default config: 128 threads, target-default warp size.
    run_get_warp_idx()
|  | 201 | + | 
|  | 202 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
    # Non-power-of-two warp size (20) over 320 threads.
    run_get_warp_idx(num_threads=320, warp_size=20)
|  | 206 | + | 
|  | 207 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
    # Default config: warp size and warps-per-group both defaulted.
    run_get_warp_group_idx()
|  | 211 | + | 
|  | 212 | + | 
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
    # Non-default group of 5 warps (32 * 5 = 160 threads per group).
    run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)
|  | 216 | + | 
|  | 217 | + | 
@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
    # One elected thread per 64-thread group.
    run_shuffle_elect(num_threads=256, thread_extent=64)
|  | 221 | + | 
|  | 222 | + | 
@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
    # Extent 0: only the block leader (thread 0) is elected.
    run_shuffle_elect(num_threads=128, thread_extent=0)
|  | 226 | + | 
if __name__ == "__main__":
    # Run all tests in this file via tilelang's pytest-based entry point.
    # (Removed stale commented-out debug call.)
    tilelang.testing.main()