Skip to content

Commit 3f40b23

Browse files
committed
lint
1 parent ac7fe0b commit 3f40b23

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

testing/python/jit/test_tilelang_jit_nullptr.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tilelang.language as T
66
from tilelang.utils import map_torch_type
77

8+
89
@tl.jit
910
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
1011

@@ -46,16 +47,17 @@ def main(
4647

4748
return main
4849

50+
4951
@tl.jit
5052
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
5153

5254
@T.prim_func
5355
def main(
54-
A: T.Tensor((M, K), dtype),
55-
B: T.Tensor((K, N), dtype),
56-
C: T.Tensor((M, N), accum_dtype),
57-
Bias: T.Tensor((N), accum_dtype),
58-
with_bias: T.bool,
56+
A: T.Tensor((M, K), dtype),
57+
B: T.Tensor((K, N), dtype),
58+
C: T.Tensor((M, N), accum_dtype),
59+
Bias: T.Tensor((N), accum_dtype),
60+
with_bias: T.bool,
5961
):
6062
# Initialize Kernel Context
6163
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -79,6 +81,7 @@ def main(
7981

8082
return main
8183

84+
8285
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
8386
func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
8487

@@ -104,8 +107,10 @@ def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="f
104107
func(a, b, c, d, True)
105108
torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
106109

110+
107111
def test_nullptr():
108112
run_test(1024, 1024, 1024, 128, 128, 32)
109113

114+
110115
if __name__ == "__main__":
111-
tilelang.testing.main()
116+
tilelang.testing.main()

0 commit comments

Comments
 (0)