
Commit c8a376c
Parent: 0e236a0

    fix comments and lint

9 files changed: +29 additions, -32 deletions

examples/amd/example_amd_flash_attn_bwd.py

Lines changed: 1 addition & 2 deletions
@@ -206,8 +206,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim_qk):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

             for i, j in T.Parallel(block_M, dim_v):
                 T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
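Note on the change above: the commit message does not state why the per-row guard in front of the dQ accumulation is removed. Dropping it is clearly behavior-preserving only when seq_len is a multiple of block_N, so that every index the parallel loop touches is already in bounds. The snippet below is a hypothetical, standalone arithmetic check of that assumption in plain Python; it is not code from the repository, and the concrete sizes are invented for illustration.

# Hypothetical check: when seq_len % block_N == 0, the dropped condition
# `k * block_N + i < seq_len` can never be false inside the loop nest,
# so removing it does not change which dQ elements get accumulated.
seq_len, block_N = 256, 64           # illustrative sizes, not from the diff
assert seq_len % block_N == 0
for k in range(seq_len // block_N):  # one iteration per block of rows
    for i in range(block_N):         # row offset within the block
        assert k * block_N + i < seq_len
print("the bounds guard was redundant for this shape")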

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 7 additions & 7 deletions
@@ -5,6 +5,7 @@
 from tilelang.profiler import do_bench
 import tilelang.language as T
 import argparse
+from typing import Optional


 def get_bwd_configs():
@@ -23,7 +24,7 @@ def get_bwd_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_fwd(
         batch,
         heads,
@@ -143,7 +144,7 @@ def flash_fwd(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -183,7 +184,7 @@ def make_dq_layout(dQ):
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -208,7 +209,7 @@ def flash_bwd_post(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd(batch,
                   heads,
                   seq_len,
@@ -311,8 +312,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+                T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])

             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
@@ -405,7 +405,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     key = key.transpose(1, 2).contiguous()
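The sliding_window annotation change above (repeated in the other attention-sink examples) swaps the PEP 604 spelling `int | None` for `typing.Optional[int]`. The two mean the same thing to a type checker, but `int | None` is evaluated when the function is defined and raises a TypeError on interpreters older than Python 3.10 (unless postponed annotation evaluation is enabled), whereas Optional[int] works on all supported versions. A minimal, hypothetical sketch of the pattern follows; `ref_program_stub` and its body are invented for illustration, and only the `sliding_window` parameter name is taken from the diff.

from typing import Optional

import torch


def ref_program_stub(query: torch.Tensor,
                     sliding_window: Optional[int] = None,
                     dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Optional[int] is Union[int, None] -- the same type as `int | None`,
    # but the annotation also evaluates cleanly on Python < 3.10.
    if sliding_window is not None:
        assert sliding_window > 0, "window size must be positive"
    return query.to(dtype)


# Both call styles are accepted by the same annotation.
ref_program_stub(torch.randn(2, 4), sliding_window=128)
ref_program_stub(torch.randn(2, 4))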

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
+from typing import Optional


 def get_configs():
@@ -29,7 +30,7 @@ def get_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn(
         batch,
         heads,
@@ -211,7 +212,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     key = key.transpose(1, 2).contiguous()
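The compile_flags edit above follows the same pattern in every attention-sink example: `--use_fast_math` is dropped while the `tilelang.PassConfigKey.TL_ENABLE_FAST_MATH` pass config stays enabled, so fast math is requested through the pass config alone. The sketch below shows the resulting decorator configuration; the diff only contains the keyword arguments, so the `tilelang.jit` decorator name, the `flashattn_stub` function, and its placeholder body are assumptions rather than code from these files.

import tilelang


# Assumed wrapper: `tilelang.jit` is the usual way TileLang example kernels
# are decorated; only the pass_configs / compile_flags keywords below appear
# in the diff itself.
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
    compile_flags=["-O3", "-DENABLE_BF16"])
def flashattn_stub(batch, heads, seq_len, dim):
    # Placeholder body: the real examples construct and return a T.prim_func.
    ...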

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 7 additions & 7 deletions
@@ -5,6 +5,7 @@
 from tilelang.profiler import do_bench
 import tilelang.language as T
 import argparse
+from typing import Optional


 def get_bwd_configs():
@@ -23,7 +24,7 @@ def get_bwd_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_fwd(
         batch,
         heads,
@@ -140,7 +141,7 @@ def flash_fwd(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,7 +181,7 @@ def make_dq_layout(dQ):
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,7 +206,7 @@ def flash_bwd_post(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd(
         batch,
         heads,
@@ -312,8 +313,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+                T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.copy(dk, dk_shared)
             T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
@@ -400,7 +400,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
 from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
+from typing import Optional


 def get_configs():
@@ -21,7 +22,7 @@ def get_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn(
         batch,
         heads,
@@ -191,7 +192,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
+from typing import Optional


 def get_configs():
@@ -25,7 +26,7 @@ def get_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn(
         batch,
         heads,
@@ -204,7 +205,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(

examples/flash_attention/example_gqa_bwd.py

Lines changed: 2 additions & 4 deletions
@@ -235,8 +235,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim_qk):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
@@ -340,8 +339,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim_qk):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

             T.copy(dv, dv_shared)
             T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

Lines changed: 2 additions & 4 deletions
@@ -245,8 +245,7 @@ def flash_bwd(
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
             T.wait_wgmma(0)
             for i, j in T.Parallel(block_N, dim_qk):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
@@ -362,8 +361,7 @@ def flash_bwd(
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
             T.wait_wgmma(0)
             for i, j in T.Parallel(block_N, dim_qk):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])

             T.copy(dv, dv_shared)
             T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])

examples/flash_attention/example_mha_bwd_bhsd.py

Lines changed: 1 addition & 2 deletions
@@ -229,8 +229,7 @@ def flash_bwd(
             T.clear(dq)
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim):
-                if k * block_N + i < seq_len:
-                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+                T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.copy(dk, dk_shared)
             T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
