forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AutoHeuristic: mixed_mm heuristic for A100 (pytorch#131613)
This PR introduces changes to AutoHeuristic that allow one to learn a heuristic as a decision tree. I used this to learn a heuristic for mixed_mm on A100 that consistenly performs better than the default choice (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L402). This is how the results look like: Explanation of columns: **wrong_max_spdup**: In the worst case, how much better would the best choice have been **wrong_gman_spdup**: For inputs where the heuristic is wrong, how much better is the best choice on average (geomean) **max_spdup_default**: Highest speedup achieved by the learned heuristic over the default choice **gman_spdup_default**: Geomean speedup achived by the learned heuristic over the default choice **max_slowdown_default**: If the default choice is better than the choice predicted by the learned heuristic, how much is it better in the worst case **non_default_preds**: Number of times the learned heuristic predicted a choice that is not the default choice **default_better**: Number of times the default choice is better than the choice made by the heuristic ``` set crit max_depth min_samples_leaf correct wrong unsure total wrong_max_spdup wrong_gman_spdup max_spdup_default gman_spdup_default max_slowdown_default non_default_preds default_better train entropy 5 0.01 2376 740 323 3439 1.855386 1.063236 11.352318 3.438279 1.022164 3116 2 test entropy 5 0.01 563 183 71 817 1.622222 1.060897 10.084181 3.507741 1.017039 746 2 ``` While the number of wrong predictions is high, on average the best choice is only around 6% better. What is important is that the choice predicted by the learned heuristic performs better than the default choice. I evaluated my heuristic on gpt-fast `meta-llama/Llama-2-7b-chat-hf` with int8 weight quantization. To get the `tuned_mixed_mm` to trigger, I had to replace `F.linear()` in https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py#L355 with `torch.matmul(input, self.weight.t().to(dtype=input.dtype))` because the mixed_mm pattern does not match if there is a transpose between a cast and the matmul. |batch size|prompt length| fallback | heuristic | speedup | |----------|-------------|------------:|------------:|--------:| | 1 | 7 | 75.31 tok/s | 148.83 tok/s| 1.97 | | 1 | 11 | 75.99 tok/s | 148.15 tok/s| 1.94 | | 4 | 7 | 103.48 tok/s | 472.00 tok/s| 4.56 | | 4 | 11 | 103.56 tok/s | 371.36 tok/s| 3.58 | | 8 | 7 | 201.92 tok/s | 813.44 tok/s| 4.02 | | 8 | 11 | 201.76 tok/s | 699.36 tok/s| 3.46 | Currently, the heuristic only applies to the following inputs: - m <= 128, k >= 1024, n >= 1024 (For these sizes, one of the triton kernels wins in most cases, but the heuristic still has to be careful to not choose a config that performs worse than the fallback) - k % 256 == 0 (If k is not a multiple of the block size, some choices perform extremely bad. In one case one config, that usually performs very well, was 130x slower.) - mat1 not transposed - mat2 transposed (In some cases, it was hard for the learned heuristic to detect some cases where it Pull Request resolved: pytorch#131613 Approved by: https://github.com/eellison
- Loading branch information
1 parent
b9cb1ab
commit 4892918
Showing
22 changed files
with
1,587 additions
and
398 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 150 additions & 0 deletions
150
torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# flake8: noqa: B950 | ||
# fmt: off | ||
# This file was generated by AutoHeuristic. Do not modify it manually! | ||
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ | ||
from typing import List, Optional, Tuple | ||
|
||
from torch._inductor.autoheuristic.autoheuristic_utils import ( | ||
AHContext, | ||
AHMetadata, | ||
Choice, | ||
) | ||
from torch._inductor.autoheuristic.learnedheuristic_interface import ( | ||
LearnedHeuristicDecision, | ||
) | ||
|
||
|
||
class MixedMMA100(LearnedHeuristicDecision): | ||
|
||
def __init__(self) -> None: | ||
self.choices: List[Choice] = [] | ||
self.fill_choices() | ||
|
||
def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: | ||
return ( | ||
metadata.name == self.get_name() | ||
and metadata.shared_memory == 166912 | ||
and str(metadata.device_capa) == "(8, 0)" | ||
) | ||
|
||
def get_confidence_threshold(self) -> float: | ||
return 0.0 | ||
|
||
def get_choice(self, idx: int) -> Optional[str]: | ||
if idx < len(self.choices): | ||
return self.choices[idx] | ||
return None | ||
|
||
def fill_choices(self) -> None: | ||
self.choices.append('extern_fallback_mixed_mm') | ||
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') | ||
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') | ||
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') | ||
|
||
def get_name(self) -> str: | ||
return 'mixed_mm' | ||
|
||
def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: | ||
if str(context.get_value('1LEQmLEQ16')) != 'True': | ||
if context.get_value('m') <= 32.5: | ||
if context.get_value('n') <= 6976.0: | ||
if context.get_value('n') <= 3520.0: | ||
if context.get_value('m*n') <= 37632.0: | ||
return None | ||
else: | ||
return [(1.000, 13)] | ||
else: | ||
if context.get_value('m*k') <= 452352.0: | ||
return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] | ||
else: | ||
return [(0.778, 8), (0.222, 13)] | ||
else: | ||
if context.get_value('k*n') <= 102776832.0: | ||
if context.get_value('n') <= 14656.0: | ||
return [(1.000, 11)] | ||
else: | ||
return [(0.889, 11), (0.111, 13)] | ||
else: | ||
return [(1.000, 11)] | ||
else: | ||
if context.get_value('m*n') <= 446464.0: | ||
if context.get_value('m*n') <= 223424.0: | ||
if context.get_value('mat1_stride_0') <= 3968.0: | ||
return None | ||
else: | ||
return None | ||
else: | ||
if context.get_value('m*n') <= 346112.0: | ||
return [(0.960, 16), (0.040, 7)] | ||
else: | ||
return [(0.750, 16), (0.136, 14), (0.114, 7)] | ||
else: | ||
if str(context.get_value('33LEQmLEQ64')) != 'True': | ||
if context.get_value('n') <= 6976.0: | ||
return [(1.000, 14)] | ||
else: | ||
return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] | ||
else: | ||
if context.get_value('n') <= 13888.0: | ||
return [(0.710, 14), (0.275, 21), (0.014, 12)] | ||
else: | ||
return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] | ||
else: | ||
if context.get_value('n') <= 3520.0: | ||
if context.get_value('arith_intensity') <= 3.994754433631897: | ||
if str(context.get_value('mat2_dtype')) != 'torch.uint8': | ||
if context.get_value('m*k') <= 18944.0: | ||
return [(0.577, 5), (0.423, 6)] | ||
else: | ||
return [(0.988, 5), (0.012, 6)] | ||
else: | ||
if context.get_value('arith_intensity') <= 2.9899919033050537: | ||
return None | ||
else: | ||
return None | ||
else: | ||
if context.get_value('arith_intensity') <= 7.956453561782837: | ||
if context.get_value('k*n') <= 9244032.0: | ||
return [(0.822, 5), (0.178, 6)] | ||
else: | ||
return [(0.977, 5), (0.023, 0)] | ||
else: | ||
if context.get_value('m*k') <= 978944.0: | ||
return [(1.000, 5)] | ||
else: | ||
return [(0.971, 5), (0.029, 0)] | ||
else: | ||
if context.get_value('n') <= 13632.0: | ||
if context.get_value('n') <= 6976.0: | ||
return [(1.000, 6)] | ||
else: | ||
if context.get_value('k') <= 3968.0: | ||
return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] | ||
else: | ||
return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] | ||
else: | ||
if context.get_value('k*n') <= 39518208.0: | ||
return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] | ||
else: | ||
if context.get_value('n') <= 20800.0: | ||
return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] | ||
else: | ||
return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.