Commit ac32bc1

Update attention sink examples to use 32 heads
- Modified the `heads` parameter in both `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` and `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` from 1 to 32 so the examples exercise a realistic multi-head attention configuration.
- Ensured consistency across the example scripts for improved usability and testing.
1 parent: e11b4c4

File tree: 3 files changed (+3, −2 lines)

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion

@@ -366,7 +366,7 @@ def gen_inputs(B, H, Sq, Skv, D,
 
 def main(
     batch: int = 1,
-    heads: int = 1,
+    heads: int = 32,
     seq_q: int = 256,
     seq_kv: int = 256,
     dim: int = 128,

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion

@@ -355,7 +355,7 @@ def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tens
 
 
 def main(batch: int = 1,
-         heads: int = 1,
+         heads: int = 32,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
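
Both example diffs change only the default of the heads argument to main(). For context, in the BHSD layout these examples use, heads sets the H dimension of the query tensor, so the new default allocates and processes 32 attention heads instead of 1. Below is a minimal sketch of the tensor shapes implied by the defaults; the dtype, the variable names, and the GQA key/value head count are illustrative assumptions, not taken from the example source:

import torch

# Defaults from main(): batch=1, heads=32, seq_q=256, seq_kv=256, dim=128.
B, H, Sq, Skv, D = 1, 32, 256, 256, 128

# Query tensor in B, H, S, D (BHSD) layout: one Sq x D slice per head.
q = torch.randn(B, H, Sq, D, dtype=torch.float16)

# MHA example: keys and values carry the same number of heads as the queries.
k_mha = torch.randn(B, H, Skv, D, dtype=torch.float16)
v_mha = torch.randn(B, H, Skv, D, dtype=torch.float16)

# GQA example: keys and values use fewer heads, which groups of query heads share;
# the head count of 8 here is purely illustrative.
kv_heads = 8
k_gqa = torch.randn(B, kv_heads, Skv, D, dtype=torch.float16)
v_gqa = torch.randn(B, kv_heads, Skv, D, dtype=torch.float16)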

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
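
The builtin.cc change registers kDisableThreadStorageSync as a boolean pass-config option alongside the existing lowering switches. A rough sketch of how such an option is typically toggled from Python through TVM's PassContext follows; the key string "tl.disable_thread_storage_sync" is an assumed mapping for kDisableThreadStorageSync and should be checked against the constant's definition before use:

import tvm

# Hypothetical key string; kDisableThreadStorageSync's actual value lives in the TileLang sources.
DISABLE_THREAD_STORAGE_SYNC = "tl.disable_thread_storage_sync"

# Boolean options registered via TVM_REGISTER_PASS_CONFIG_OPTION are set on a
# PassContext; compilation performed inside the context sees the flag.
with tvm.transform.PassContext(config={DISABLE_THREAD_STORAGE_SYNC: True}):
    # ... build or JIT-compile the attention kernels here ...
    pass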
