Skip to content

Commit 09ed919

Browse files
committed
[Lint]
1 parent b86874b commit 09ed919

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

examples/dequantize_gemm/example_dequant_gemm_mxfp4_hopper_serial.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,7 @@ def matmul(M,
6969
scale_size=32,
7070
tune=False):
7171

72-
@tilelang.jit(
73-
out_idx=[-1],
74-
)
72+
@tilelang.jit(out_idx=[-1],)
7573
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
7674
num_elems_per_byte = 8 // num_bits
7775
storage_dtype = "uint8"
@@ -81,7 +79,6 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
8179
A_shared_shape = (block_M, block_K)
8280
B_shared_shape = (block_N, block_K // num_elems_per_byte)
8381
B_dequantize_shared_shape = (block_N, block_K)
84-
Scale_shared_shape = (block_N, block_K // scale_size)
8582
assert K % (block_K * split) == 0
8683

8784
# Some variables for serial dequant in each thread
@@ -121,7 +118,6 @@ def main(
121118
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
122119
B_dequantize_local_thread = T.alloc_local((local_size,), in_dtype)
123120
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
124-
Scale_shared = T.alloc_shared(Scale_shared_shape, storage_dtype)
125121
Scale_local_thread = T.alloc_local((1,), storage_dtype)
126122
Scale_local_thread_exponent = T.alloc_local((1,), "float32")
127123

@@ -158,7 +154,8 @@ def main(
158154
index_scale = index_base // (scale_size // num_elems_per_byte)
159155
si = index_scale // (block_K // scale_size)
160156
sj = index_scale % (block_K // scale_size)
161-
Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
157+
Scale_local_thread[0] = Scale[bx * block_N + si,
158+
k * block_K // scale_size + sj]
162159
Scale_local_thread_exponent[0] = T.exp2(
163160
T.cast(Scale_local_thread[0] - 127, "float"))
164161

0 commit comments

Comments (0)