Skip to content

Commit c856ced

Browse files
committed
[BugFix] Add smem swizzle to recover performance of TMA
1 parent 0149177 commit c856ced

File tree

1 file changed

+4 -2 lines changed

1 file changed

+4 -2 lines changed

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,9 +33,9 @@ def get_configs():
3333
iter_params = dict(
3434
block_M=[64, 128, 256],
3535
block_N=[64, 128, 256],
36-
block_K=[128],
36+
block_K=[64, 128, 256],
3737
num_stages=[0, 2],
38-
threads=[128, 256, 512],
38+
threads=[128, 256],
3939
split=[1, 2],
4040
)
4141
return [{
@@ -186,6 +186,8 @@ def main(
186186
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
187187

188188
T.annotate_layout({
189+
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
190+
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
189191
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
190192
})
191193

0 commit comments

Comments
 (0)