
Commit 0e236a0
fix ci and lint
Parent: e8fa8bb

File tree
2 files changed (+9 lines, -12 lines)


examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 4 additions & 4 deletions
@@ -32,10 +32,10 @@ def flashattn_fwd(
         groups=1,
         window_size=None,  # None for full attention
         sm_scale=None,
-        block_M=128,
-        block_N=128,
-        num_stages=2,
-        threads=256,
+        block_M=64,
+        block_N=64,
+        num_stages=1,
+        threads=128,
         dtype: str = "float16"):
 
     if window_size is not None:
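The hunk above only changes flashattn_fwd's default launch configuration in the GQA example (block_M/block_N 128 to 64, num_stages 2 to 1, threads 256 to 128), presumably so the example fits within CI resource limits. Because these are ordinary keyword defaults, callers can still request the old, larger configuration explicitly. A minimal sketch of that behavior, using a hypothetical stub whose keyword defaults mirror the diff; the real flashattn_fwd in example_gqa_sink_bwd_bhsd.py takes the problem shape as positional arguments and returns a compiled attention kernel rather than a dict:

# Hypothetical stub for illustration only: it copies the keyword defaults shown
# in the diff above and just reports the launch configuration it would use.
def flashattn_fwd_stub(*shape_args,
                       groups=1,
                       window_size=None,
                       sm_scale=None,
                       block_M=64,
                       block_N=64,
                       num_stages=1,
                       threads=128,
                       dtype: str = "float16"):
    return dict(block_M=block_M, block_N=block_N, num_stages=num_stages, threads=threads)

# New defaults apply when no tile sizes are passed:
print(flashattn_fwd_stub(1, 8, 2048, 128))
# {'block_M': 64, 'block_N': 64, 'num_stages': 1, 'threads': 128}

# The previous, larger configuration can still be requested explicitly:
print(flashattn_fwd_stub(1, 8, 2048, 128, block_M=128, block_N=128, num_stages=2, threads=256))
# {'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}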

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 5 additions & 8 deletions
@@ -31,10 +31,10 @@ def flashattn_fwd(
         dim,
         window_size=None,  # None for full attention,
         sm_scale=None,
-        block_M=128,
-        block_N=32,
-        num_stages=2,
-        threads=256,
+        block_M=64,
+        block_N=64,
+        num_stages=1,
+        threads=128,
         dtype: str = "float16"):
 
     if window_size is not None:
@@ -356,11 +356,8 @@ class _attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, sinks, window_size):
         BATCH, H, N_CTX, D_HEAD = q.shape
-        block_M = 64
-        block_N = 64 if D_HEAD <= 128 else 32
         dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
-        kernel = flashattn_fwd(
-            BATCH, H, N_CTX, D_HEAD, window_size, block_M=block_M, block_N=block_N, dtype=dtype)
+        kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
         o, lse = kernel(q, k, v, sinks)
         ctx.save_for_backward(q, k, v, sinks, o, lse)
         ctx.window_size = window_size
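One side effect worth noting: the old forward() picked block_N = 32 for head dimensions above 128, while the new call site always falls back to flashattn_fwd's default of block_N = 64. A small sketch of that difference, derived purely from the diff above and not run against the real kernels:

def old_block_n(d_head: int) -> int:
    # Selection removed by this commit: smaller key/value tiles for large head dims.
    return 64 if d_head <= 128 else 32

def new_block_n(d_head: int) -> int:
    # forward() no longer chooses; flashattn_fwd's default (64) applies
    # unless a caller overrides it.
    return 64

for d in (64, 128, 256):
    print(d, old_block_n(d), new_block_n(d))
# 64  64 64
# 128 64 64
# 256 32 64   <- head dims above 128 now default to larger tiles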
