Skip to content

Commit 1d1f48e

Browse files
committed
fix previous typos
1 parent 985931d commit 1d1f48e

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None): # None
211211
block_M, block_N, num_stages, threads = get_bwd_configs()
212212

213213
if window_size is not None:
214-
assert window_size % block_N == 0, "window_size must be divisible by block_M"
214+
assert window_size % block_N == 0, "window_size must be divisible by block_N"
215215

216216
@T.prim_func
217217
def flash_bwd(

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def flashattn_bwd(
216216
accum_dtype = "float"
217217

218218
if window_size is not None:
219-
assert window_size % block_N == 0, "window_size must be divisible by block_M"
219+
assert window_size % block_N == 0, "window_size must be divisible by block_N"
220220

221221
@T.prim_func
222222
def flash_bwd(

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def flash_bwd_post(
146146
@tilelang.jit(pass_configs={
147147
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
148148
})
149-
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
149+
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
150150
sm_scale = (1.0 / dim)**0.5
151151
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
152152
shape = [batch, seq_len, heads, dim]
@@ -198,7 +198,7 @@ def flash_bwd(
198198
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
199199
T.clear(dv)
200200
T.clear(dk)
201-
loop_st = T.floordiv(by * block_M, block_N) if is_casual else 0
201+
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
202202
loop_ed = T.ceildiv(seq_len, block_N)
203203
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
204204
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
@@ -219,7 +219,7 @@ def flash_bwd(
219219
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
220220
for i, j in T.Parallel(block_M, block_N):
221221
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
222-
if is_casual:
222+
if is_causal:
223223
for i, j in T.Parallel(block_M, block_N):
224224
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
225225
0)

0 commit comments

Comments
 (0)