Skip to content

Commit cebc6ff

Browse files
[Lint]: [pre-commit.ci] auto fixes [...]
1 parent cc2330f commit cebc6ff

File tree

1 file changed

+10
-27
lines changed

1 file changed

+10
-27
lines changed

testing/python/language/test_tilelang_language_get_warp_info.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
3434

3535

3636
@tilelang.jit(out_idx=[-1])
37-
def _get_warp_idx_sync_kernel(
38-
num_threads: int = 128, warp_size: Optional[int] = None
39-
):
37+
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
4038

4139
@T.prim_func
4240
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
@@ -76,9 +74,7 @@ def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
7674

7775

7876
@tilelang.jit(out_idx=[-1])
79-
def _shuffle_elect_kernel(
80-
num_threads: int = 128, thread_extent: int = 64
81-
):
77+
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
8278

8379
@T.prim_func
8480
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
@@ -96,24 +92,18 @@ def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
9692
print(kernel.get_kernel_source())
9793
print(A)
9894
expected_warp_size = _resolve_warp_size(warp_size)
99-
ref = torch.arange(
100-
num_threads, dtype=A.dtype, device=A.device
101-
) % expected_warp_size
95+
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size
10296
torch.testing.assert_close(A.cpu(), ref.cpu())
10397
return A
10498

10599

106-
def run_get_warp_idx_sync(
107-
num_threads: int = 128, warp_size: Optional[int] = None
108-
):
100+
def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
109101
kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
110102
A = kernel()
111103
print(kernel.get_kernel_source())
112104
print(A)
113105
expected_warp_size = _resolve_warp_size(warp_size)
114-
ref = torch.arange(
115-
num_threads, dtype=A.dtype, device=A.device
116-
) // expected_warp_size
106+
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
117107
torch.testing.assert_close(A.cpu(), ref.cpu())
118108
return A
119109

@@ -124,9 +114,7 @@ def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
124114
print(kernel.get_kernel_source())
125115
print(A)
126116
expected_warp_size = _resolve_warp_size(warp_size)
127-
ref = torch.arange(
128-
num_threads, dtype=A.dtype, device=A.device
129-
) // expected_warp_size
117+
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
130118
torch.testing.assert_close(A.cpu(), ref.cpu())
131119
return A
132120

@@ -145,25 +133,19 @@ def run_get_warp_group_idx(
145133
threads_per_group = expected_warp_size * expected_warps_per_group
146134
if threads_per_group <= 0:
147135
raise ValueError("threads_per_group must be positive.")
148-
ref = torch.arange(
149-
num_threads, dtype=A.dtype, device=A.device
150-
) // threads_per_group
136+
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group
151137
torch.testing.assert_close(A.cpu(), ref.cpu())
152138
return A
153139

154140

155-
def run_shuffle_elect(
156-
num_threads: int = 128, thread_extent: int = 64
157-
):
141+
def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
158142
if thread_extent < 0:
159143
raise ValueError("thread_extent must be non-negative.")
160144
kernel = _shuffle_elect_kernel(num_threads, thread_extent)
161145
A = kernel()
162146
print(kernel.get_kernel_source())
163147
print(A)
164-
indices = torch.arange(
165-
num_threads, device=A.device, dtype=torch.int64
166-
)
148+
indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
167149
if thread_extent == 0:
168150
mask = indices == 0
169151
elif thread_extent > 0:
@@ -224,6 +206,7 @@ def test_shuffle_elect_default():
224206
def test_shuffle_elect_block_leader():
225207
run_shuffle_elect(num_threads=128, thread_extent=0)
226208

209+
227210
if __name__ == "__main__":
228211
tilelang.testing.main()
229212
# run_get_lane_id()

0 commit comments

Comments
 (0)