Commit b78d840

[Language] Expose T.get_warp_idx_sync and T.shuffle_elect for efficient thread election (#989)
* Expose CUDA warp/lane intrinsics in the TileLang frontend
* Generalize the warp-indexing intrinsics and add test coverage
* [Lint]: [pre-commit.ci] auto fixes [...]

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 32ddc1a commit b78d840
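
For orientation, and not part of the diff itself: a minimal sketch of how the newly exposed frontend intrinsics might be combined for per-warp thread election, assuming the same TileLang APIs the new tests below exercise (`tilelang.jit`, `T.Kernel`, `T.get_thread_binding`). The kernel name and output layout are illustrative only.

```python
# Hedged sketch, not from the commit: elect one leader thread per warp and
# record which warp it belongs to. Assumes T.shuffle_elect(extent) yields 1
# for exactly one thread in each `extent`-sized group and 0 otherwise, as the
# new tests verify.
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def warp_leader_kernel(num_threads: int = 128, warp_size: int = 32):

    @T.prim_func
    def kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            elected = T.shuffle_elect(warp_size)  # 1 for the warp leader, else 0
            # Leaders store (warp index + 1); every other thread stores 0.
            A[tx] = elected * (T.get_warp_idx_sync(warp_size) + 1)

    return kernel


# Usage on a CUDA device: leaders = warp_leader_kernel()()
```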

File tree

6 files changed: +504 -3 lines changed

src/op/builtin.cc

Lines changed: 20 additions & 0 deletions

```diff
@@ -218,6 +218,26 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(get_lane_idx)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(get_warp_idx)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
+    .set_num_inputs(-1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
 TIR_DEFINE_TL_BUILTIN(wait_wgmma)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
```

src/op/builtin.h

Lines changed: 32 additions & 0 deletions

```diff
@@ -358,6 +358,38 @@ TVM_DLL const Op &warpgroup_commit_batch();
  */
 TVM_DLL const Op &warpgroup_wait();
 
+/*!
+ * \brief Return the canonical lane index for the calling thread.
+ *
+ * get_lane_idx([warp_size])
+ *
+ */
+TVM_DLL const Op &get_lane_idx();
+
+/*!
+ * \brief Return the canonical warp index, assuming converged threads.
+ *
+ * get_warp_idx_sync([warp_size])
+ *
+ */
+TVM_DLL const Op &get_warp_idx_sync();
+
+/*!
+ * \brief Return the canonical warp index without synchronizing the warp.
+ *
+ * get_warp_idx([warp_size])
+ *
+ */
+TVM_DLL const Op &get_warp_idx();
+
+/*!
+ * \brief Return the canonical warp group index for converged threads.
+ *
+ * get_warp_group_idx([warp_size, warps_per_group])
+ *
+ */
+TVM_DLL const Op &get_warp_group_idx();
+
 /*!
  * \brief Wait the previous wgmma to finish
  *
```

src/target/codegen_cuda.cc

Lines changed: 35 additions & 0 deletions

```diff
@@ -1968,6 +1968,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     enable_sparse_gemm_ = true;
     this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
                           op->args, true, os);
+  } else if (op->op.same_as(tl::get_lane_idx())) {
+    ICHECK_LE(op->args.size(), 1)
+        << "tl.get_lane_idx expects at most one argument <warp_size>.";
+    os << "tl::get_lane_idx(";
+    if (!op->args.empty()) {
+      os << PrintExpr(op->args[0]);
+    }
+    os << ")";
+  } else if (op->op.same_as(tl::get_warp_idx_sync())) {
+    ICHECK_LE(op->args.size(), 1)
+        << "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
+    os << "tl::get_warp_idx_sync(";
+    if (!op->args.empty()) {
+      os << PrintExpr(op->args[0]);
+    }
+    os << ")";
+  } else if (op->op.same_as(tl::get_warp_idx())) {
+    ICHECK_LE(op->args.size(), 1)
+        << "tl.get_warp_idx expects at most one argument <warp_size>.";
+    os << "tl::get_warp_idx(";
+    if (!op->args.empty()) {
+      os << PrintExpr(op->args[0]);
+    }
+    os << ")";
+  } else if (op->op.same_as(tl::get_warp_group_idx())) {
+    ICHECK_LE(op->args.size(), 2)
+        << "tl.get_warp_group_idx expects <warp_size, warps_per_group>.";
+    os << "tl::get_warp_group_idx(";
+    for (size_t i = 0; i < op->args.size(); ++i) {
+      if (i != 0) {
+        os << ", ";
+      }
+      os << PrintExpr(op->args[i]);
+    }
+    os << ")";
   } else if (op->op.same_as(tl::tl_shuffle_elect())) {
     os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
   } else if (op->op.same_as(tl::initialize_descriptor())) {
```
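
A hedged sanity check, not in the commit: each branch above prints a call to a `tl::` device helper, so the lowered CUDA source should contain the matching call string. The snippet below assumes a CUDA-enabled machine and reuses `_get_warp_idx_sync_kernel` and `get_kernel_source()` exactly as the new test file below does.

```python
# Hedged sketch: inspect the generated CUDA for the device-side helper call.
# _get_warp_idx_sync_kernel is the @tilelang.jit helper defined in the test
# file below; calling it triggers compilation for the given parameters.
kernel = _get_warp_idx_sync_kernel(num_threads=128, warp_size=32)
source = kernel.get_kernel_source()
assert "tl::get_warp_idx_sync(" in source  # emitted by the codegen branch above
```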

src/tl_templates/cuda/intrin.h

Lines changed: 56 additions & 2 deletions

```diff
@@ -1,12 +1,65 @@
 #pragma once
 
+#include "common.h"
+#include "cutlass/cutlass.h"
+
 #if __CUDA_ARCH_LIST__ >= 900
 #include "cute/arch/cluster_sm90.hpp"
 #include "cute/arch/mma_sm90_gmma.hpp"
-#include "cutlass/cutlass.h"
+#endif
 
 namespace tl {
 
+namespace detail {
+
+// Provide architecture-specific defaults so callers may omit arguments.
+TL_DEVICE constexpr int default_warp_size() {
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__)
+  return 64;
+#else
+  return 32;
+#endif
+}
+
+TL_DEVICE constexpr int default_warps_per_group() { return 4; }
+
+TL_DEVICE int linear_thread_idx_in_block() {
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+  return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
+#else
+  return 0;
+#endif
+}
+
+} // namespace detail
+
+TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) {
+  warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
+  return detail::linear_thread_idx_in_block() % warp_size;
+}
+
+TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) {
+  warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
+  return detail::linear_thread_idx_in_block() / warp_size;
+}
+
+TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) {
+  warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
+  return detail::linear_thread_idx_in_block() / warp_size;
+}
+
+TL_DEVICE int
+get_warp_group_idx(int warp_size = detail::default_warp_size(),
+                   int warps_per_group = detail::default_warps_per_group()) {
+  warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
+  warps_per_group =
+      warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group();
+  int threads_per_group = warp_size * warps_per_group;
+  threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size;
+  return detail::linear_thread_idx_in_block() / threads_per_group;
+}
+
+#if __CUDA_ARCH_LIST__ >= 900
 TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
 TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
 
@@ -61,5 +114,6 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
 template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
   asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
 }
-} // namespace tl
 #endif
+
+} // namespace tl
```
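
To summarize the defaults encoded in the device helpers above (warp size 32 on CUDA, 64 on HIP, four warps per warp group), here is a hedged host-side Python sketch of the same index arithmetic; the function name is illustrative and not part of the commit.

```python
# Hedged reference of the device helpers' arithmetic, assuming a flat per-block
# thread index `tid` as computed by detail::linear_thread_idx_in_block().
def reference_indices(tid: int, warp_size: int = 32, warps_per_group: int = 4):
    lane_idx = tid % warp_size                             # tl::get_lane_idx
    warp_idx = tid // warp_size                            # tl::get_warp_idx(_sync)
    warp_group_idx = tid // (warp_size * warps_per_group)  # tl::get_warp_group_idx
    return lane_idx, warp_idx, warp_group_idx


# Thread 70 with CUDA defaults: lane 6 of warp 2, warp group 0.
assert reference_indices(70) == (6, 2, 0)
```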
Lines changed: 212 additions & 0 deletions (new test file)

```python
from typing import Optional

import tilelang.language as T
import tilelang.testing
import torch
from tilelang.utils.target import check_hip_availability

_IS_HIP_AVAILABLE = check_hip_availability()
_DEFAULT_WARPS_PER_GROUP = 4


def _resolve_warp_size(warp_size: Optional[int]) -> int:
    if warp_size is not None:
        return int(warp_size)
    return 64 if _IS_HIP_AVAILABLE else 32


def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
    if warps_per_group is not None:
        return int(warps_per_group)
    return _DEFAULT_WARPS_PER_GROUP


@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

    @T.prim_func
    def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_lane_idx(warp_size)

    return laneid_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

    @T.prim_func
    def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx_sync(warp_size)

    return warp_idx_sync_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):

    @T.prim_func
    def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx(warp_size)

    return warp_idx_kernel


@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):

    @T.prim_func
    def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)

    return warp_group_idx_kernel


@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):

    @T.prim_func
    def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            elected = T.shuffle_elect(thread_extent)
            A[tx] = elected

    return shuffle_elect_kernel


def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
    kernel = _get_laneid_kernel(num_threads, warp_size)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    expected_warp_size = _resolve_warp_size(warp_size)
    ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A


def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
    kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    expected_warp_size = _resolve_warp_size(warp_size)
    ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A


def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
    kernel = _get_warp_idx_kernel(num_threads, warp_size)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    expected_warp_size = _resolve_warp_size(warp_size)
    ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A


def run_get_warp_group_idx(
    num_threads: int = 128,
    warp_size: Optional[int] = None,
    warps_per_group: Optional[int] = None,
):
    kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    expected_warp_size = _resolve_warp_size(warp_size)
    expected_warps_per_group = _resolve_warps_per_group(warps_per_group)
    threads_per_group = expected_warp_size * expected_warps_per_group
    if threads_per_group <= 0:
        raise ValueError("threads_per_group must be positive.")
    ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A


def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
    if thread_extent < 0:
        raise ValueError("thread_extent must be non-negative.")
    kernel = _shuffle_elect_kernel(num_threads, thread_extent)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
    if thread_extent == 0:
        mask = indices == 0
    elif thread_extent > 0:
        mask = (indices % thread_extent) == 0
    else:
        mask = torch.zeros_like(indices, dtype=torch.bool)
    ref = mask.to(dtype=A.dtype, device=A.device)
    torch.testing.assert_close(A.cpu(), ref.cpu())
    return A


@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
    run_get_lane_id()


@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
    run_get_lane_id(num_threads=256, warp_size=64)


@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
    run_get_warp_idx_sync()


@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
    run_get_warp_idx_sync(num_threads=256, warp_size=16)


@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
    run_get_warp_idx()


@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
    run_get_warp_idx(num_threads=320, warp_size=20)


@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
    run_get_warp_group_idx()


@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
    run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)


@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
    run_shuffle_elect(num_threads=256, thread_extent=64)


@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
    run_shuffle_elect(num_threads=128, thread_extent=0)


if __name__ == "__main__":
    tilelang.testing.main()
    # run_get_lane_id()
```
