Skip to content

Commit cc2330f

Browse files
committed
generalize warp indexing intrinsics and add coverage
1 parent 9bffa4a commit cc2330f

File tree

7 files changed

+395
-43
lines changed

src/op/builtin.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,22 +219,22 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
219219
Integer(CallEffectKind::kOpaque));
220220

221221
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
222-
.set_num_inputs(0)
222+
.set_num_inputs(-1)
223223
.set_attr<TCallEffectKind>("TCallEffectKind",
224224
Integer(CallEffectKind::kPure));
225225

226226
TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
227-
.set_num_inputs(0)
227+
.set_num_inputs(-1)
228228
.set_attr<TCallEffectKind>("TCallEffectKind",
229229
Integer(CallEffectKind::kPure));
230230

231231
TIR_DEFINE_TL_BUILTIN(get_warp_idx)
232-
.set_num_inputs(0)
232+
.set_num_inputs(-1)
233233
.set_attr<TCallEffectKind>("TCallEffectKind",
234234
Integer(CallEffectKind::kPure));
235235

236236
TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
237-
.set_num_inputs(0)
237+
.set_num_inputs(-1)
238238
.set_attr<TCallEffectKind>("TCallEffectKind",
239239
Integer(CallEffectKind::kPure));
240240

src/op/builtin.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,31 +361,31 @@ TVM_DLL const Op &warpgroup_wait();
361361
/*!
362362
* \brief Return the canonical lane index for the calling thread.
363363
*
364-
* get_lane_idx()
364+
* get_lane_idx([warp_size])
365365
*
366366
*/
367367
TVM_DLL const Op &get_lane_idx();
368368

369369
/*!
370370
* \brief Return the canonical warp index, assuming converged threads.
371371
*
372-
* get_warp_idx_sync()
372+
* get_warp_idx_sync([warp_size])
373373
*
374374
*/
375375
TVM_DLL const Op &get_warp_idx_sync();
376376

377377
/*!
378378
* \brief Return the canonical warp index without synchronizing the warp.
379379
*
380-
* get_warp_idx()
380+
* get_warp_idx([warp_size])
381381
*
382382
*/
383383
TVM_DLL const Op &get_warp_idx();
384384

385385
/*!
386386
* \brief Return the canonical warp group index for converged threads.
387387
*
388-
* get_warp_group_idx()
388+
* get_warp_group_idx([warp_size, warps_per_group])
389389
*
390390
*/
391391
TVM_DLL const Op &get_warp_group_idx();

src/target/codegen_cuda.cc

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,13 +1969,40 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
19691969
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
19701970
op->args, true, os);
19711971
} else if (op->op.same_as(tl::get_lane_idx())) {
1972-
os << "tl::get_lane_idx()";
1972+
ICHECK_LE(op->args.size(), 1)
1973+
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
1974+
os << "tl::get_lane_idx(";
1975+
if (!op->args.empty()) {
1976+
os << PrintExpr(op->args[0]);
1977+
}
1978+
os << ")";
19731979
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
1974-
os << "tl::get_warp_idx_sync()";
1980+
ICHECK_LE(op->args.size(), 1)
1981+
<< "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
1982+
os << "tl::get_warp_idx_sync(";
1983+
if (!op->args.empty()) {
1984+
os << PrintExpr(op->args[0]);
1985+
}
1986+
os << ")";
19751987
} else if (op->op.same_as(tl::get_warp_idx())) {
1976-
os << "tl::get_warp_idx()";
1988+
ICHECK_LE(op->args.size(), 1)
1989+
<< "tl.get_warp_idx expects at most one argument <warp_size>.";
1990+
os << "tl::get_warp_idx(";
1991+
if (!op->args.empty()) {
1992+
os << PrintExpr(op->args[0]);
1993+
}
1994+
os << ")";
19771995
} else if (op->op.same_as(tl::get_warp_group_idx())) {
1978-
os << "tl::get_warp_group_idx()";
1996+
ICHECK_LE(op->args.size(), 2)
1997+
<< "tl.get_warp_group_idx expects <warp_size, warps_per_group>.";
1998+
os << "tl::get_warp_group_idx(";
1999+
for (size_t i = 0; i < op->args.size(); ++i) {
2000+
if (i != 0) {
2001+
os << ", ";
2002+
}
2003+
os << PrintExpr(op->args[i]);
2004+
}
2005+
os << ")";
19792006
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
19802007
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
19812008
} else if (op->op.same_as(tl::initialize_descriptor())) {

src/tl_templates/cuda/intrin.h

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,53 @@
1010

1111
namespace tl {
1212

13-
TL_DEVICE int get_lane_idx() { return cutlass::canonical_lane_idx(); }
13+
namespace detail {
1414

15-
TL_DEVICE int get_warp_idx_sync() { return cutlass::canonical_warp_idx_sync(); }
15+
// Provide architecture-specific defaults so callers may omit arguments.
16+
TL_DEVICE constexpr int default_warp_size() {
17+
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__)
18+
return 64;
19+
#else
20+
return 32;
21+
#endif
22+
}
23+
24+
TL_DEVICE constexpr int default_warps_per_group() { return 4; }
25+
26+
TL_DEVICE int linear_thread_idx_in_block() {
27+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
28+
return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
29+
#else
30+
return 0;
31+
#endif
32+
}
1633

17-
TL_DEVICE int get_warp_idx() { return cutlass::canonical_warp_idx(); }
34+
} // namespace detail
35+
36+
TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) {
37+
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
38+
return detail::linear_thread_idx_in_block() % warp_size;
39+
}
40+
41+
TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) {
42+
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
43+
return detail::linear_thread_idx_in_block() / warp_size;
44+
}
45+
46+
TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) {
47+
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
48+
return detail::linear_thread_idx_in_block() / warp_size;
49+
}
1850

19-
TL_DEVICE int get_warp_group_idx() {
20-
return cutlass::canonical_warp_group_idx();
51+
TL_DEVICE int
52+
get_warp_group_idx(int warp_size = detail::default_warp_size(),
53+
int warps_per_group = detail::default_warps_per_group()) {
54+
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
55+
warps_per_group =
56+
warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group();
57+
int threads_per_group = warp_size * warps_per_group;
58+
threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size;
59+
return detail::linear_thread_idx_in_block() / threads_per_group;
2160
}
2261

2362
#if __CUDA_ARCH_LIST__ >= 900
(file path missing from this capture — presumably the new warp-intrinsics Python test file; confirm against the original commit)

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from typing import Optional
2+
3+
import tilelang.language as T
4+
import tilelang.testing
5+
import torch
6+
from tilelang.utils.target import check_hip_availability
7+
8+
_IS_HIP_AVAILABLE = check_hip_availability()
9+
_DEFAULT_WARPS_PER_GROUP = 4
10+
11+
12+
def _resolve_warp_size(warp_size: Optional[int]) -> int:
13+
if warp_size is not None:
14+
return int(warp_size)
15+
return 64 if _IS_HIP_AVAILABLE else 32
16+
17+
18+
def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
19+
if warps_per_group is not None:
20+
return int(warps_per_group)
21+
return _DEFAULT_WARPS_PER_GROUP
22+
23+
24+
@tilelang.jit(out_idx=[-1])
25+
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
26+
27+
@T.prim_func
28+
def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
29+
with T.Kernel(1, threads=num_threads) as _:
30+
tx = T.get_thread_binding()
31+
A[tx] = T.get_lane_idx(warp_size)
32+
33+
return laneid_kernel
34+
35+
36+
@tilelang.jit(out_idx=[-1])
37+
def _get_warp_idx_sync_kernel(
38+
num_threads: int = 128, warp_size: Optional[int] = None
39+
):
40+
41+
@T.prim_func
42+
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
43+
with T.Kernel(1, threads=num_threads) as _:
44+
tx = T.get_thread_binding()
45+
A[tx] = T.get_warp_idx_sync(warp_size)
46+
47+
return warp_idx_sync_kernel
48+
49+
50+
@tilelang.jit(out_idx=[-1])
51+
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
52+
53+
@T.prim_func
54+
def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
55+
with T.Kernel(1, threads=num_threads) as _:
56+
tx = T.get_thread_binding()
57+
A[tx] = T.get_warp_idx(warp_size)
58+
59+
return warp_idx_kernel
60+
61+
62+
@tilelang.jit(out_idx=[-1])
63+
def _get_warp_group_idx_kernel(
64+
num_threads: int = 128,
65+
warp_size: Optional[int] = None,
66+
warps_per_group: Optional[int] = None,
67+
):
68+
69+
@T.prim_func
70+
def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
71+
with T.Kernel(1, threads=num_threads) as _:
72+
tx = T.get_thread_binding()
73+
A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)
74+
75+
return warp_group_idx_kernel
76+
77+
78+
@tilelang.jit(out_idx=[-1])
79+
def _shuffle_elect_kernel(
80+
num_threads: int = 128, thread_extent: int = 64
81+
):
82+
83+
@T.prim_func
84+
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
85+
with T.Kernel(1, threads=num_threads) as _:
86+
tx = T.get_thread_binding()
87+
elected = T.shuffle_elect(thread_extent)
88+
A[tx] = elected
89+
90+
return shuffle_elect_kernel
91+
92+
93+
def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
94+
kernel = _get_laneid_kernel(num_threads, warp_size)
95+
A = kernel()
96+
print(kernel.get_kernel_source())
97+
print(A)
98+
expected_warp_size = _resolve_warp_size(warp_size)
99+
ref = torch.arange(
100+
num_threads, dtype=A.dtype, device=A.device
101+
) % expected_warp_size
102+
torch.testing.assert_close(A.cpu(), ref.cpu())
103+
return A
104+
105+
106+
def run_get_warp_idx_sync(
107+
num_threads: int = 128, warp_size: Optional[int] = None
108+
):
109+
kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
110+
A = kernel()
111+
print(kernel.get_kernel_source())
112+
print(A)
113+
expected_warp_size = _resolve_warp_size(warp_size)
114+
ref = torch.arange(
115+
num_threads, dtype=A.dtype, device=A.device
116+
) // expected_warp_size
117+
torch.testing.assert_close(A.cpu(), ref.cpu())
118+
return A
119+
120+
121+
def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
122+
kernel = _get_warp_idx_kernel(num_threads, warp_size)
123+
A = kernel()
124+
print(kernel.get_kernel_source())
125+
print(A)
126+
expected_warp_size = _resolve_warp_size(warp_size)
127+
ref = torch.arange(
128+
num_threads, dtype=A.dtype, device=A.device
129+
) // expected_warp_size
130+
torch.testing.assert_close(A.cpu(), ref.cpu())
131+
return A
132+
133+
134+
def run_get_warp_group_idx(
135+
num_threads: int = 128,
136+
warp_size: Optional[int] = None,
137+
warps_per_group: Optional[int] = None,
138+
):
139+
kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
140+
A = kernel()
141+
print(kernel.get_kernel_source())
142+
print(A)
143+
expected_warp_size = _resolve_warp_size(warp_size)
144+
expected_warps_per_group = _resolve_warps_per_group(warps_per_group)
145+
threads_per_group = expected_warp_size * expected_warps_per_group
146+
if threads_per_group <= 0:
147+
raise ValueError("threads_per_group must be positive.")
148+
ref = torch.arange(
149+
num_threads, dtype=A.dtype, device=A.device
150+
) // threads_per_group
151+
torch.testing.assert_close(A.cpu(), ref.cpu())
152+
return A
153+
154+
155+
def run_shuffle_elect(
156+
num_threads: int = 128, thread_extent: int = 64
157+
):
158+
if thread_extent < 0:
159+
raise ValueError("thread_extent must be non-negative.")
160+
kernel = _shuffle_elect_kernel(num_threads, thread_extent)
161+
A = kernel()
162+
print(kernel.get_kernel_source())
163+
print(A)
164+
indices = torch.arange(
165+
num_threads, device=A.device, dtype=torch.int64
166+
)
167+
if thread_extent == 0:
168+
mask = indices == 0
169+
elif thread_extent > 0:
170+
mask = (indices % thread_extent) == 0
171+
else:
172+
mask = torch.zeros_like(indices, dtype=torch.bool)
173+
ref = mask.to(dtype=A.dtype, device=A.device)
174+
torch.testing.assert_close(A.cpu(), ref.cpu())
175+
return A
176+
177+
178+
@tilelang.testing.requires_cuda
179+
def test_get_lane_idx_default():
180+
run_get_lane_id()
181+
182+
183+
@tilelang.testing.requires_cuda
184+
def test_get_lane_idx_custom():
185+
run_get_lane_id(num_threads=256, warp_size=64)
186+
187+
188+
@tilelang.testing.requires_cuda
189+
def test_get_warp_idx_sync_default():
190+
run_get_warp_idx_sync()
191+
192+
193+
@tilelang.testing.requires_cuda
194+
def test_get_warp_idx_sync_custom():
195+
run_get_warp_idx_sync(num_threads=256, warp_size=16)
196+
197+
198+
@tilelang.testing.requires_cuda
199+
def test_get_warp_idx_default():
200+
run_get_warp_idx()
201+
202+
203+
@tilelang.testing.requires_cuda
204+
def test_get_warp_idx_custom():
205+
run_get_warp_idx(num_threads=320, warp_size=20)
206+
207+
208+
@tilelang.testing.requires_cuda
209+
def test_get_warp_group_idx_default():
210+
run_get_warp_group_idx()
211+
212+
213+
@tilelang.testing.requires_cuda
214+
def test_get_warp_group_idx_custom():
215+
run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)
216+
217+
218+
@tilelang.testing.requires_cuda
219+
def test_shuffle_elect_default():
220+
run_shuffle_elect(num_threads=256, thread_extent=64)
221+
222+
223+
@tilelang.testing.requires_cuda
224+
def test_shuffle_elect_block_leader():
225+
run_shuffle_elect(num_threads=128, thread_extent=0)
226+
227+
if __name__ == "__main__":
228+
tilelang.testing.main()
229+
# run_get_lane_id()

0 commit comments

Comments
 (0)