AutoHeuristic: mixed_mm heuristic for A100 (pytorch#131613)
This PR introduces changes to AutoHeuristic that allow one to learn a heuristic as a decision tree. I used this to learn a heuristic for mixed_mm on A100 that consistently performs better than the default choice (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py#L402).
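
Learning such a heuristic essentially boils down to fitting a classifier on collected (input features → fastest choice) data. Below is a minimal sketch of that step using scikit-learn; the log file name, column names, and feature set are illustrative assumptions, not the actual AutoHeuristic pipeline (the hyperparameters mirror the results table below):

```
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical autotuning log: one row per input, with the measured
# fastest choice as the label.
df = pd.read_csv("mixed_mm_a100_data.csv")
features = df[["m", "k", "n", "mat1_transposed", "mat2_transposed"]]
labels = df["best_choice"]

# crit=entropy, max_depth=5, min_samples_leaf=0.01 as in the table below
# (min_samples_leaf given as a float is a fraction of the training samples).
tree = DecisionTreeClassifier(
    criterion="entropy", max_depth=5, min_samples_leaf=0.01
)
tree.fit(features, labels)
```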

Here is what the results look like:
Explanation of columns:
**wrong_max_spdup**: For inputs where the heuristic is wrong, how much better the best choice would have been in the worst case
**wrong_gman_spdup**: For inputs where the heuristic is wrong, how much better the best choice is on average (geomean)
**max_spdup_default**: Highest speedup achieved by the learned heuristic over the default choice
**gman_spdup_default**: Geomean speedup achieved by the learned heuristic over the default choice
**max_slowdown_default**: If the default choice is better than the choice predicted by the learned heuristic, how much better it is in the worst case
**non_default_preds**: Number of times the learned heuristic predicted a choice that is not the default choice
**default_better**: Number of times the default choice is better than the choice made by the heuristic
```
  set     crit  max_depth  min_samples_leaf  correct  wrong  unsure  total  wrong_max_spdup  wrong_gman_spdup    max_spdup_default  gman_spdup_default  max_slowdown_default  non_default_preds  default_better
train  entropy          5              0.01     2376    740     323   3439         1.855386          1.063236            11.352318            3.438279              1.022164               3116               2
 test  entropy          5              0.01      563    183      71    817         1.622222          1.060897            10.084181            3.507741              1.017039                746               2
```

While the number of wrong predictions is high, the best choice is on average only around 6% better on those inputs. What matters is that the choice predicted by the learned heuristic performs better than the default choice.
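
For reference, the geomean columns are geometric means of per-input speedups. A minimal sketch of the computation, with made-up speedup values:

```
import numpy as np

# Hypothetical per-input speedups of the best choice over the heuristic's choice.
speedups = np.array([1.02, 1.10, 1.01, 1.12])
geomean = np.exp(np.log(speedups).mean())  # == speedups.prod() ** (1 / len(speedups))
print(f"{geomean:.3f}")  # ~1.061, i.e. "around 6% better"
```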

I evaluated my heuristic on gpt-fast `meta-llama/Llama-2-7b-chat-hf` with int8 weight quantization. To get the `tuned_mixed_mm` to trigger, I had to replace `F.linear()` in https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py#L355 with `torch.matmul(input, self.weight.t().to(dtype=input.dtype))` because the mixed_mm pattern does not match if there is a transpose between a cast and the matmul.
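
A sketch of that replacement (tensor and parameter names follow the PR text; the surrounding gpt-fast module is omitted):

```
import torch
import torch.nn.functional as F

def forward_before(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # F.linear transposes the already-cast weight internally, i.e. there is
    # a transpose between the cast and the matmul, so the mixed_mm pattern
    # does not match.
    return F.linear(input, weight.to(dtype=input.dtype))

def forward_after(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Transpose the int8 weight first and cast right before the matmul,
    # so the mixed_mm pattern matches and tuned_mixed_mm can trigger.
    return torch.matmul(input, weight.t().to(dtype=input.dtype))
```
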
| batch size | prompt length |     fallback |    heuristic | speedup |
|------------|---------------|-------------:|-------------:|--------:|
| 1          | 7             |  75.31 tok/s | 148.83 tok/s |    1.97 |
| 1          | 11            |  75.99 tok/s | 148.15 tok/s |    1.94 |
| 4          | 7             | 103.48 tok/s | 472.00 tok/s |    4.56 |
| 4          | 11            | 103.56 tok/s | 371.36 tok/s |    3.58 |
| 8          | 7             | 201.92 tok/s | 813.44 tok/s |    4.02 |
| 8          | 11            | 201.76 tok/s | 699.36 tok/s |    3.46 |

Currently, the heuristic only applies to the following inputs (a sketch of this applicability check follows the list):
- m <= 128, k >= 1024, n >= 1024 (For these sizes, one of the triton kernels wins in most cases, but the heuristic still has to be careful not to choose a config that performs worse than the fallback)
- k % 256 == 0 (If k is not a multiple of the block size, some choices perform extremely badly; in one case, a config that usually performs very well was 130x slower)
- mat1 not transposed
- mat2 transposed (In some cases, it was hard for the learned heuristic to detect cases where it […])
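
A minimal sketch of the applicability check described above (the function name and signature are illustrative, not the actual AutoHeuristic API):

```
def mixed_mm_heuristic_applies(
    m: int, k: int, n: int, mat1_transposed: bool, mat2_transposed: bool
) -> bool:
    # Outside this region the learned heuristic is not used and the
    # default choice applies.
    return (
        m <= 128
        and k >= 1024
        and n >= 1024
        and k % 256 == 0  # avoid block-size mismatch pathologies
        and not mat1_transposed
        and mat2_transposed
    )
```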

Pull Request resolved: pytorch#131613
Approved by: https://github.com/eellison
AlnisM authored and pytorchmergebot committed Aug 2, 2024
1 parent b9cb1ab commit 4892918
Showing 22 changed files with 1,587 additions and 398 deletions.
3 changes: 3 additions & 0 deletions .lintrunner.toml
@@ -15,6 +15,7 @@ exclude_patterns = [
'functorch/examples/**',
'functorch/notebooks/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/autoheuristic/artifacts/**',
'scripts/**',
'test/generated_type_hints_smoketest.py',
# Tests from the NumPy test suite
@@ -1029,6 +1030,7 @@ exclude_patterns = [
'third_party/**/*.py',
'third_party/**/*.pyi',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/autoheuristic/artifacts/**',
# These files are all grandfathered in, feel free to remove from this list
# as necessary
'test/_nvfuser/__init__.py',
@@ -1540,6 +1542,7 @@ exclude_patterns = [
'functorch/docs/**',
'functorch/notebooks/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/autoheuristic/artifacts/**',
'scripts/**',
'third_party/**',
'fb/**',
1 change: 1 addition & 0 deletions pyproject.toml
@@ -186,6 +186,7 @@ select = [
]
# autogenerated #TODO figure out why file level noqa is ignored
"torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"]
"torch/_inductor/autoheuristic/artifacts/**" = ["F401", "F501"]
"torchgen/api/types/__init__.py" = [
"F401",
"F403",
22 changes: 18 additions & 4 deletions test/inductor/test_autoheuristic.py
@@ -127,14 +127,22 @@ def test_autoheuristic_h100(self):
# TODO (AlnisM): Find a way to check whether heuristic is used
self.run_mm()

@inductor_config.patch(autoheuristic_collect="mixed_mm")
def test_global_feedback(self):
def run_mixed_mm(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))

a = torch.randn(8, 8, device="cuda")
b = torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda")
a = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
b = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8, device="cuda").t()
torch.compile(fn, mode="max-autotune-no-cudagraphs")(a, b)

# have to set autoheuristic_use="" because if autoheuristic_use="mixed_mm",
# autoheuristic creates a precompile key, puts it into the registry, and then
# a choice made by the heuristic might be added to the list of choices
# and if select_algorithm now creates a new precompile key, it will be
# different from the precompile key created by autoheuristic
@inductor_config.patch(autoheuristic_collect="mixed_mm", autoheuristic_use="")
def test_global_feedback(self):
self.run_mixed_mm()
path = self.get_path_to_autoheuristic_log("mixed_mm")
self.assertTrue(os.path.exists(path))
num_lines = self.count_lines_in_file(path)
@@ -143,6 +151,12 @@ def fn(a, b):
# 1 line for fallback + at least 1 config
self.assertTrue(num_lines > 4)

@inductor_config.patch(autoheuristic_use="mixed_mm")
@unittest.skipIf(not IS_A100, "heuristic only run on A100")
def test_mixed_mm_a100(self):
self.run_mixed_mm()
# TODO (AlnisM): Find a way to check whether heuristic is used


if __name__ == "__main__":
if HAS_CUDA:
6 changes: 3 additions & 3 deletions test/inductor/test_pattern_matcher.py
@@ -277,12 +277,12 @@ def fn(a, b, c, d):

@unittest.skipIf(not SM80OrLater, "need sm_80")
@unittest.skipIf(not IS_A100, "heuristic only run on Linux A100")
@inductor_config.patch(mixed_mm_choice="heuristic")
@inductor_config.patch(mixed_mm_choice="heuristic", autoheuristic_use="")
def test_mixed_mm_heuristic_no(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))

# examples that should not be selected by heuristic
# examples that should not be selected by handwritten heuristic
mat1_dtype = torch.float16
dyn_tensor = torch.randn(4, 4096, dtype=mat1_dtype, device="cuda")
torch._dynamo.mark_dynamic(dyn_tensor, 0)
@@ -336,7 +336,7 @@ def fn(a, b):
return torch.mm(a, b.to(a.dtype))

mat1_dtype = torch.float16
# examples that should be selected by heuristic
# examples that should be selected by handwritten heuristic
args_list = [
(
torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"),
150 changes: 150 additions & 0 deletions torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py
@@ -0,0 +1,150 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple

from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicDecision,
)


class MixedMMA100(LearnedHeuristicDecision):

def __init__(self) -> None:
self.choices: List[Choice] = []
self.fill_choices()

def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
return (
metadata.name == self.get_name()
and metadata.shared_memory == 166912
and str(metadata.device_capa) == "(8, 0)"
)

def get_confidence_threshold(self) -> float:
return 0.0

def get_choice(self, idx: int) -> Optional[str]:
if idx < len(self.choices):
return self.choices[idx]
return None

def fill_choices(self) -> None:
self.choices.append('extern_fallback_mixed_mm')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')

def get_name(self) -> str:
return 'mixed_mm'

def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
if str(context.get_value('1LEQmLEQ16')) != 'True':
if context.get_value('m') <= 32.5:
if context.get_value('n') <= 6976.0:
if context.get_value('n') <= 3520.0:
if context.get_value('m*n') <= 37632.0:
return None
else:
return [(1.000, 13)]
else:
if context.get_value('m*k') <= 452352.0:
return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)]
else:
return [(0.778, 8), (0.222, 13)]
else:
if context.get_value('k*n') <= 102776832.0:
if context.get_value('n') <= 14656.0:
return [(1.000, 11)]
else:
return [(0.889, 11), (0.111, 13)]
else:
return [(1.000, 11)]
else:
if context.get_value('m*n') <= 446464.0:
if context.get_value('m*n') <= 223424.0:
if context.get_value('mat1_stride_0') <= 3968.0:
return None
else:
return None
else:
if context.get_value('m*n') <= 346112.0:
return [(0.960, 16), (0.040, 7)]
else:
return [(0.750, 16), (0.136, 14), (0.114, 7)]
else:
if str(context.get_value('33LEQmLEQ64')) != 'True':
if context.get_value('n') <= 6976.0:
return [(1.000, 14)]
else:
return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)]
else:
if context.get_value('n') <= 13888.0:
return [(0.710, 14), (0.275, 21), (0.014, 12)]
else:
return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)]
else:
if context.get_value('n') <= 3520.0:
if context.get_value('arith_intensity') <= 3.994754433631897:
if str(context.get_value('mat2_dtype')) != 'torch.uint8':
if context.get_value('m*k') <= 18944.0:
return [(0.577, 5), (0.423, 6)]
else:
return [(0.988, 5), (0.012, 6)]
else:
if context.get_value('arith_intensity') <= 2.9899919033050537:
return None
else:
return None
else:
if context.get_value('arith_intensity') <= 7.956453561782837:
if context.get_value('k*n') <= 9244032.0:
return [(0.822, 5), (0.178, 6)]
else:
return [(0.977, 5), (0.023, 0)]
else:
if context.get_value('m*k') <= 978944.0:
return [(1.000, 5)]
else:
return [(0.971, 5), (0.029, 0)]
else:
if context.get_value('n') <= 13632.0:
if context.get_value('n') <= 6976.0:
return [(1.000, 6)]
else:
if context.get_value('k') <= 3968.0:
return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)]
else:
return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)]
else:
if context.get_value('k*n') <= 39518208.0:
return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)]
else:
if context.get_value('n') <= 20800.0:
return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)]
else:
return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)]
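
For context on how this generated file is consumed: `get_best_choices` returns (probability, choice-index) pairs, highest probability first, and `get_choice` maps an index back to the corresponding config string. A hedged usage sketch (building a real `AHContext` is elided; inductor actually goes through the heuristic controller rather than calling this directly):

```
from typing import Optional

from torch._inductor.autoheuristic.artifacts._MixedMMA100 import MixedMMA100
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext

def resolve_choice(heuristic: MixedMMA100, context: AHContext) -> Optional[str]:
    best = heuristic.get_best_choices(context)
    if not best:
        return None  # the tree abstained; the caller falls back to the default
    _prob, idx = best[0]              # e.g. (0.778, 8)
    return heuristic.get_choice(idx)  # e.g. a 'type=triton_...' config string
```
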
8 changes: 5 additions & 3 deletions torch/_inductor/autoheuristic/artifacts/_PadMMA100.py
@@ -6,10 +6,12 @@
Choice,
CHOICE_COL,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic
from torch._inductor.autoheuristic.learnedheuristic_interface import (
LearnedHeuristicRegression,
)


class PadMMA100(LearnedHeuristic):
class PadMMA100(LearnedHeuristicRegression):
def __init__(self) -> None:
pass

@@ -28,7 +30,7 @@ def get_feedback(self, context: AHContext, choice: Choice) -> float:
context.context_dict[CHOICE_COL] = choice
return self.predict(context)

def get_speedup_threshold(self) -> float:
def get_confidence_threshold(self) -> float:
return 1.7025303314066

def get_name(self) -> str:
25 changes: 5 additions & 20 deletions torch/_inductor/autoheuristic/autoheuristic.py
@@ -1,7 +1,7 @@
import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
@@ -12,6 +12,7 @@
CHOICE_COL,
Feedback,
FEEDBACK_COL,
get_metadata_str_from_log,
)
from torch._inductor.autoheuristic.learned_heuristic_controller import (
LearnedHeuristicController,
@@ -21,25 +22,6 @@
from torch._inductor.utils import get_gpu_shared_memory


def deserialize_data(log_path: str) -> Tuple[Any, Dict[str, Any]]:
json_string = get_metadata_str_from_log(log_path)
metadata = deserialize_metadata(json_string)
import pandas as pd # type: ignore[import-untyped]

df = pd.read_csv(log_path, skiprows=1)
return (df, metadata)


def deserialize_metadata(json_string: str) -> Dict[str, Any]:
return json.loads(json_string)


def get_metadata_str_from_log(log_path: str) -> str:
with open(log_path, newline="") as file:
json_string = file.readline().strip()
return json_string


class LocalFeedback:
"""
To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
@@ -147,6 +129,9 @@ def get_choice(self) -> Choice:
self.context,
)
decision = controller.get_decision()
if decision not in self.choices:
# TODO(AlnisM): We might want to allow this in the future
return self.fallback()
if decision is not None:
return decision
return self.fallback()
(Diffs for the remaining files are not shown.)