Skip to content

Commit e3a5dab

Browse files
committed
[Lint]
1 parent 73e5040 commit e3a5dab

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
300300
bx * block_N + i, k * block_K // scale_size + j //
301301
scale_size], # Scale is the exponential part, within the representation of uint8
302302
dtype=out_dtype,
303-
) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
303+
) * T.shift_left(
304+
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
304305
T.copy(B_dequantize_local, B_dequantize_shared)
305306

306307
return simple_dequant_bf16_fp4
@@ -426,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
426427

427428
if tune:
428429
kernel = matmul(
429-
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant)
430+
m,
431+
n,
432+
k,
433+
"bfloat16",
434+
"bfloat16",
435+
"float32",
436+
num_bits=4,
437+
scale_size=scale_size,
438+
fast_dequant=fast_dequant)
430439
else:
431440
kernel = matmul(
432441
m,

0 commit comments

Comments
 (0)