Skip to content

Commit 27f501d

Browse files
authored
Dynamic autotune configs for devices with warp size != 32 (Dao-AILab#1534)
Generate a list of autotune configs based on device warp size to avoid triton error if maximum threads per block is exceeded.
1 parent 4b5eeab commit 27f501d

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

flash_attn/ops/triton/layer_norm.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
import triton
1616
import triton.language as tl
1717

18+
def triton_autotune_configs():
    """Return a list of autotune configs valid for the current device.

    Generates one ``triton.Config`` per power-of-two warp count such that
    ``warp_count * warp_size`` does not exceed the maximum threads per
    block. This avoids a Triton launch error on devices whose warp size
    is not 32 (e.g. AMD GPUs with warp size 64, where num_warps=32 would
    request 2048 threads per block).

    Returns:
        list[triton.Config]: empty-kwarg configs differing only in
        ``num_warps`` (1, 2, 4, ... up to the per-block thread limit).
    """
    # Maximum threads per block is architecture-dependent in theory,
    # but in practice it is 1024 on all current devices.
    max_threads_per_block = 1024
    # Default to warp size 32 when the device does not report one, or when
    # no CUDA device is available at all. The availability guard matters
    # because this function is evaluated at module-import time (inside the
    # @triton.autotune decorator arguments); without it, importing this
    # module on a CPU-only machine would raise from current_device().
    warp_size = 32
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        warp_size = getattr(props, "warp_size", 32)
    # Autotune over warp counts that are powers of 2 and stay within the
    # thread-per-block limit.
    configs = []
    warp_count = 1
    while warp_count * warp_size <= max_threads_per_block:
        configs.append(triton.Config({}, num_warps=warp_count))
        warp_count *= 2
    return configs
1831

1932
def layer_norm_ref(
2033
x,
@@ -126,14 +139,7 @@ def rms_norm_ref(
126139

127140

128141
@triton.autotune(
129-
configs=[
130-
triton.Config({}, num_warps=1),
131-
triton.Config({}, num_warps=2),
132-
triton.Config({}, num_warps=4),
133-
triton.Config({}, num_warps=8),
134-
triton.Config({}, num_warps=16),
135-
triton.Config({}, num_warps=32),
136-
],
142+
configs=triton_autotune_configs(),
137143
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
138144
)
139145
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@@ -393,14 +399,7 @@ def _layer_norm_fwd(
393399

394400

395401
@triton.autotune(
396-
configs=[
397-
triton.Config({}, num_warps=1),
398-
triton.Config({}, num_warps=2),
399-
triton.Config({}, num_warps=4),
400-
triton.Config({}, num_warps=8),
401-
triton.Config({}, num_warps=16),
402-
triton.Config({}, num_warps=32),
403-
],
402+
configs=triton_autotune_configs(),
404403
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
405404
)
406405
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})

0 commit comments

Comments
 (0)