|
15 | 15 | import triton
|
16 | 16 | import triton.language as tl
|
17 | 17 |
|
def triton_autotune_configs():
    """Build the list of ``triton.Config`` candidates for ``@triton.autotune``.

    Returns one config per power-of-two warp count (1, 2, 4, ...) such that
    ``warp_count * warp_size`` does not exceed the per-block thread limit,
    so autotuning never proposes an invalid launch for the current device.

    Returns:
        list[triton.Config]: configs differing only in ``num_warps``.
    """
    # NOTE(review): relies on a module-level `import torch` — not visible in
    # this chunk; confirm it exists at the top of the file.
    configs = []
    # Maximum threads per block is architecture-dependent in theory,
    # but in reality all current CUDA devices use 1024.
    max_threads_per_block = 1024
    # This is evaluated at import time (decorator argument), so it must not
    # crash on CPU-only hosts: query the device only when CUDA is available,
    # otherwise fall back to the conventional warp size of 32.
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        # Default to warp size 32 if the property is not exposed.
        warp_size = getattr(props, "warp_size", 32)
    else:
        warp_size = 32
    # Autotune over warp counts that are powers of 2 and fit the block limit.
    warp_count = 1
    while warp_count * warp_size <= max_threads_per_block:
        configs.append(triton.Config({}, num_warps=warp_count))
        warp_count *= 2
    return configs
18 | 31 |
|
19 | 32 | def layer_norm_ref(
|
20 | 33 | x,
|
@@ -126,14 +139,7 @@ def rms_norm_ref(
|
126 | 139 |
|
127 | 140 |
|
128 | 141 | @triton.autotune(
|
129 |
| - configs=[ |
130 |
| - triton.Config({}, num_warps=1), |
131 |
| - triton.Config({}, num_warps=2), |
132 |
| - triton.Config({}, num_warps=4), |
133 |
| - triton.Config({}, num_warps=8), |
134 |
| - triton.Config({}, num_warps=16), |
135 |
| - triton.Config({}, num_warps=32), |
136 |
| - ], |
| 142 | + configs=triton_autotune_configs(), |
137 | 143 | key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
138 | 144 | )
|
139 | 145 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
@@ -393,14 +399,7 @@ def _layer_norm_fwd(
|
393 | 399 |
|
394 | 400 |
|
395 | 401 | @triton.autotune(
|
396 |
| - configs=[ |
397 |
| - triton.Config({}, num_warps=1), |
398 |
| - triton.Config({}, num_warps=2), |
399 |
| - triton.Config({}, num_warps=4), |
400 |
| - triton.Config({}, num_warps=8), |
401 |
| - triton.Config({}, num_warps=16), |
402 |
| - triton.Config({}, num_warps=32), |
403 |
| - ], |
| 402 | + configs=triton_autotune_configs(), |
404 | 403 | key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
|
405 | 404 | )
|
406 | 405 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
|
0 commit comments