20 changes: 20 additions & 0 deletions src/op/builtin.cc
@@ -218,6 +218,26 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
32 changes: 32 additions & 0 deletions src/op/builtin.h
@@ -358,6 +358,38 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();

/*!
* \brief Return the canonical lane index for the calling thread.
*
* get_lane_idx([warp_size])
*
*/
TVM_DLL const Op &get_lane_idx();

/*!
* \brief Return the canonical warp index, assuming converged threads.
*
* get_warp_idx_sync([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx_sync();

/*!
* \brief Return the canonical warp index without synchronizing the warp.
*
* get_warp_idx([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx();

/*!
* \brief Return the canonical warp group index for converged threads.
*
* get_warp_group_idx([warp_size, warps_per_group])
*
*/
TVM_DLL const Op &get_warp_group_idx();

/*!
* \brief Wait the previous wgmma to finish
*
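These declarations back the T.get_lane_idx / T.get_warp_idx_sync / T.get_warp_idx / T.get_warp_group_idx intrinsics exercised by the test file added later in this PR. A minimal usage sketch in the Python frontend (the binding names are taken from that test; everything else is illustrative only):

from typing import Optional

import tilelang
import tilelang.language as T


# Every thread writes its lane index into A. Passing warp_size=None lets the
# device helper fall back to the architecture default (32 on CUDA, 64 on HIP).
@tilelang.jit(out_idx=[-1])
def lane_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

    @T.prim_func
    def kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # T.get_warp_idx_sync / T.get_warp_idx / T.get_warp_group_idx
            # follow the same call pattern with their own optional arguments.
            A[tx] = T.get_lane_idx(warp_size)

    return kernel
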
35 changes: 35 additions & 0 deletions src/target/codegen_cuda.cc
@@ -1968,6 +1968,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::get_lane_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
os << "tl::get_lane_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
os << "tl::get_warp_idx_sync(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx expects at most one argument <warp_size>.";
os << "tl::get_warp_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_group_idx())) {
ICHECK_LE(op->args.size(), 2)
<< "tl.get_warp_group_idx expects <warp_size, warps_per_group>.";
os << "tl::get_warp_group_idx(";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << PrintExpr(op->args[i]);
}
os << ")";
Comment on lines +1971 to +2005

⚠️ Potential issue | 🔴 Critical

Add missing include for tl intrinsics to avoid undefined references

Emission for tl::get_lane_idx / tl::get_warp_idx(_sync) / tl::get_warp_group_idx looks correct and arity checks are fine. However, there’s no include for the tl intrinsics header; this can fail to compile when these symbols aren’t already brought in indirectly.

Include the header alongside other tl headers in Finish():

   decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
   if (enable_sparse_gemm_) {
     decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
   }
   decl_stream << "#include <tl_templates/cuda/copy.h>\n";
   decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
   decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
   decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
   decl_stream << "#include <tl_templates/cuda/debug.h>\n";
+  decl_stream << "#include <tl_templates/cuda/intrin.h>\n";
   decl_stream << "#ifdef ENABLE_BF16\n";

} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) {
58 changes: 56 additions & 2 deletions src/tl_templates/cuda/intrin.h
@@ -1,12 +1,65 @@
#pragma once

#include "common.h"
#include "cutlass/cutlass.h"

#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cutlass/cutlass.h"
#endif

namespace tl {

namespace detail {

// Provide architecture-specific defaults so callers may omit arguments.
TL_DEVICE constexpr int default_warp_size() {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}

TL_DEVICE constexpr int default_warps_per_group() { return 4; }

TL_DEVICE int linear_thread_idx_in_block() {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
#else
return 0;
#endif
}

} // namespace detail

TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() % warp_size;
}

TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}

TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}

TL_DEVICE int
get_warp_group_idx(int warp_size = detail::default_warp_size(),
int warps_per_group = detail::default_warps_per_group()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
warps_per_group =
warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group();
int threads_per_group = warp_size * warps_per_group;
threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size;
return detail::linear_thread_idx_in_block() / threads_per_group;
}

#if __CUDA_ARCH_LIST__ >= 900
TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }

@@ -61,5 +114,6 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
#endif

} // namespace tl
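
To make the mapping concrete, the helpers above reduce to integer division and modulo on the block-linear thread id. A plain-Python sketch (not part of the PR) of the same arithmetic, assuming the CUDA defaults of warp_size = 32 and warps_per_group = 4:

WARP_SIZE = 32        # detail::default_warp_size() on CUDA
WARPS_PER_GROUP = 4   # detail::default_warps_per_group()


def warp_info(linear_tid: int):
    lane = linear_tid % WARP_SIZE                          # get_lane_idx
    warp = linear_tid // WARP_SIZE                         # get_warp_idx(_sync)
    group = linear_tid // (WARP_SIZE * WARPS_PER_GROUP)    # get_warp_group_idx
    return lane, warp, group


# warp_info(0)   -> (0, 0, 0)
# warp_info(37)  -> (5, 1, 0)
# warp_info(200) -> (8, 6, 1)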
212 changes: 212 additions & 0 deletions testing/python/language/test_tilelang_language_get_warp_info.py
@@ -0,0 +1,212 @@
from typing import Optional

import tilelang.language as T
import tilelang.testing
import torch
from tilelang.utils.target import check_hip_availability

_IS_HIP_AVAILABLE = check_hip_availability()
_DEFAULT_WARPS_PER_GROUP = 4


def _resolve_warp_size(warp_size: Optional[int]) -> int:
if warp_size is not None:
return int(warp_size)
return 64 if _IS_HIP_AVAILABLE else 32


def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
if warps_per_group is not None:
return int(warps_per_group)
return _DEFAULT_WARPS_PER_GROUP


@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

@T.prim_func
def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_lane_idx(warp_size)

return laneid_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

@T.prim_func
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx_sync(warp_size)

return warp_idx_sync_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

@T.prim_func
def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx(warp_size)

return warp_idx_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):

@T.prim_func
def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)

return warp_group_idx_kernel


@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):

@T.prim_func
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
elected = T.shuffle_elect(thread_extent)
A[tx] = elected

return shuffle_elect_kernel


def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_laneid_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A


def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A


def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A


def run_get_warp_group_idx(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
expected_warps_per_group = _resolve_warps_per_group(warps_per_group)
threads_per_group = expected_warp_size * expected_warps_per_group
if threads_per_group <= 0:
raise ValueError("threads_per_group must be positive.")
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group
torch.testing.assert_close(A.cpu(), ref.cpu())
return A


def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
if thread_extent < 0:
raise ValueError("thread_extent must be non-negative.")
kernel = _shuffle_elect_kernel(num_threads, thread_extent)
A = kernel()
print(kernel.get_kernel_source())
print(A)
indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
if thread_extent == 0:
mask = indices == 0
elif thread_extent > 0:
mask = (indices % thread_extent) == 0
else:
mask = torch.zeros_like(indices, dtype=torch.bool)
ref = mask.to(dtype=A.dtype, device=A.device)
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
Comment on lines +141 to +157

⚠️ Potential issue | 🟡 Minor

Remove unreachable dead code.

Lines 153-154 are unreachable because:

  • Line 142 raises an exception if thread_extent < 0
  • Line 149 handles thread_extent == 0
  • Line 151 handles thread_extent > 0

The else branch can never execute.

Apply this diff to remove the dead code:

     if thread_extent == 0:
         mask = indices == 0
     elif thread_extent > 0:
         mask = (indices % thread_extent) == 0
-    else:
-        mask = torch.zeros_like(indices, dtype=torch.bool)
     ref = mask.to(dtype=A.dtype, device=A.device)
🧰 Tools
🪛 Ruff (0.13.3)

143-143: Avoid specifying long messages outside the exception class

(TRY003)




@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
run_get_lane_id()


@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
run_get_lane_id(num_threads=256, warp_size=64)


@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
run_get_warp_idx_sync()


@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
run_get_warp_idx_sync(num_threads=256, warp_size=16)


@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
run_get_warp_idx()


@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
run_get_warp_idx(num_threads=320, warp_size=20)


@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
run_get_warp_group_idx()


@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)


@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
run_shuffle_elect(num_threads=256, thread_extent=64)


@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
run_shuffle_elect(num_threads=128, thread_extent=0)


if __name__ == "__main__":
tilelang.testing.main()
# run_get_lane_id()