- 
        Couldn't load subscription status. 
- Fork 286
          [Language] Expose T.get_warp_idx_sync and T.shuffle_elect for efficient thread election
          #989
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9bffa4a
              cc2330f
              cebc6ff
              2d38bd4
              45adbaf
              65b77aa
              3616825
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,212 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang.language as T | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tilelang.testing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tilelang.utils.target import check_hip_availability | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _IS_HIP_AVAILABLE = check_hip_availability() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _DEFAULT_WARPS_PER_GROUP = 4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_warp_size(warp_size: Optional[int]) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if warp_size is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return int(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 64 if _IS_HIP_AVAILABLE else 32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if warps_per_group is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return int(warps_per_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _DEFAULT_WARPS_PER_GROUP | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.jit(out_idx=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def laneid_kernel(A: T.Tensor((num_threads,), "int32")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(1, threads=num_threads) as _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tx = T.get_thread_binding() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A[tx] = T.get_lane_idx(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return laneid_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.jit(out_idx=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(1, threads=num_threads) as _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tx = T.get_thread_binding() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A[tx] = T.get_warp_idx_sync(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return warp_idx_sync_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.jit(out_idx=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(1, threads=num_threads) as _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tx = T.get_thread_binding() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A[tx] = T.get_warp_idx(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return warp_idx_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.jit(out_idx=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_warp_group_idx_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_threads: int = 128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warp_size: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warps_per_group: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(1, threads=num_threads) as _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tx = T.get_thread_binding() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A[tx] = T.get_warp_group_idx(warp_size, warps_per_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return warp_group_idx_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.jit(out_idx=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(1, threads=num_threads) as _: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tx = T.get_thread_binding() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elected = T.shuffle_elect(thread_extent) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A[tx] = elected | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return shuffle_elect_kernel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _get_laneid_kernel(num_threads, warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = kernel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected_warp_size = _resolve_warp_size(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(A.cpu(), ref.cpu()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _get_warp_idx_sync_kernel(num_threads, warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = kernel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected_warp_size = _resolve_warp_size(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(A.cpu(), ref.cpu()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _get_warp_idx_kernel(num_threads, warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = kernel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected_warp_size = _resolve_warp_size(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(A.cpu(), ref.cpu()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_get_warp_group_idx( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_threads: int = 128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warp_size: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warps_per_group: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = kernel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected_warp_size = _resolve_warp_size(warp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expected_warps_per_group = _resolve_warps_per_group(warps_per_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads_per_group = expected_warp_size * expected_warps_per_group | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if threads_per_group <= 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("threads_per_group must be positive.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(A.cpu(), ref.cpu()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if thread_extent < 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("thread_extent must be non-negative.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = _shuffle_elect_kernel(num_threads, thread_extent) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A = kernel() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(kernel.get_kernel_source()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(A) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| indices = torch.arange(num_threads, device=A.device, dtype=torch.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if thread_extent == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = indices == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif thread_extent > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = (indices % thread_extent) == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = torch.zeros_like(indices, dtype=torch.bool) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ref = mask.to(dtype=A.dtype, device=A.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(A.cpu(), ref.cpu()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return A | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +141
     to 
      +157
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unreachable dead code. Lines 153-154 are unreachable because: 
 The else branch can never execute. Apply this diff to remove the dead code:      if thread_extent == 0:
         mask = indices == 0
     elif thread_extent > 0:
         mask = (indices % thread_extent) == 0
-    else:
-        mask = torch.zeros_like(indices, dtype=torch.bool)
     ref = mask.to(dtype=A.dtype, device=A.device)📝 Committable suggestion
 
        Suggested change
       
 🧰 Tools🪛 Ruff (0.13.3)143-143: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_lane_idx_default(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_lane_id() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_lane_idx_custom(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_lane_id(num_threads=256, warp_size=64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_idx_sync_default(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_idx_sync() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_idx_sync_custom(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_idx_sync(num_threads=256, warp_size=16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_idx_default(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_idx() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_idx_custom(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_idx(num_threads=320, warp_size=20) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_group_idx_default(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_group_idx() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_get_warp_group_idx_custom(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_shuffle_elect_default(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_shuffle_elect(num_threads=256, thread_extent=64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @tilelang.testing.requires_cuda | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_shuffle_elect_block_leader(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_shuffle_elect(num_threads=128, thread_extent=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.testing.main() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # run_get_lane_id() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add missing include for tl intrinsics to avoid undefined references
Emission for tl::get_lane_idx / tl::get_warp_idx(_sync) / tl::get_warp_group_idx looks correct and arity checks are fine. However, there’s no include for the tl intrinsics header; this can fail to compile when these symbols aren’t already brought in indirectly.
Include the header alongside other tl headers in Finish():
decl_stream << "#include <tl_templates/cuda/gemm.h>\n"; if (enable_sparse_gemm_) { decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n"; } decl_stream << "#include <tl_templates/cuda/copy.h>\n"; decl_stream << "#include <tl_templates/cuda/reduce.h>\n"; decl_stream << "#include <tl_templates/cuda/ldsm.h>\n"; decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n"; decl_stream << "#include <tl_templates/cuda/debug.h>\n"; + decl_stream << "#include <tl_templates/cuda/intrin.h>\n"; decl_stream << "#ifdef ENABLE_BF16\n";📝 Committable suggestion
🤖 Prompt for AI Agents