Skip to content

Commit e905328

Browse files
committed
lint fix
1 parent fb68698 commit e905328

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

testing/python/language/test_tilelang_language_vectorize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import tilelang.language as T
66

77

8+
@tilelang.jit
89
def vectorize_test(N, M, stride_A, stride_B):
910
assert N % 128 == 0 and M % 128 == 0
1011

1112
@T.prim_func
1213
def main(
13-
A: T.StridedTensor[(N, M), (1, stride_A), "float32"],
14-
B: T.StridedTensor[(N, M), (1, stride_B), "float32"],
14+
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821
15+
B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821
1516
):
1617
with T.Kernel(M // 128, threads=128) as (bx):
1718
tx = T.get_thread_binding(0)
@@ -26,8 +27,7 @@ def main(
2627
def run_vectorize(N, M, stride_A, stride_B):
2728
assert stride_A >= N and stride_B >= N
2829

29-
program = vectorize_test(N, M, stride_A, stride_B)
30-
jit_kernel = tl.compile(program, target="cuda", execution_backend="cython")
30+
jit_kernel = vectorize_test(N, M, stride_A, stride_B)
3131

3232
base_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32)
3333
base_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32)

0 commit comments

Comments
 (0)