Skip to content

Commit

Permalink
AutoHeuristic: Enable explicit support for ranking (pytorch#131710)
Browse files Browse the repository at this point in the history
This PR adds support for heuristics that rank choices in AutoHeuristic.

Pull Request resolved: pytorch#131710
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#131705
  • Loading branch information
AlnisM authored and pytorchmergebot committed Aug 16, 2024
1 parent add0f00 commit 3a904d1
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
20 changes: 20 additions & 0 deletions torch/_inductor/autoheuristic/autoheuristic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,20 @@ def get_choice(self) -> Choice:
return decision
return self.fallback()

def get_top_k_choices(self, top_k: int) -> Optional[List[Choice]]:
if not self.satisfies_precondition():
return None
if torch._inductor.config.use_autoheuristic(self.name):
if self.augment_context is not None:
self.context.apply_operations(self.augment_context)
controller = LearnedHeuristicController(
self.metadata,
self.context,
)
choices = controller.get_decisions_ranked(top_k)
return choices
return None

def get_collected_feedback(self, choice: Choice) -> Any:
return self.collected_feedback.get(choice, None)

Expand Down Expand Up @@ -283,3 +297,9 @@ def store_global_feedback(
def get_choice_caller(self) -> Optional[ChoiceCaller]:
choice = self.get_choice()
return self.choicestr2choice.get(choice, None)

def get_top_k_choices_caller(self, top_k: int) -> Optional[List[ChoiceCaller]]:
choices = self.get_top_k_choices(top_k)
if choices is None:
return None
return [self.choicestr2choice[choice] for choice in choices]
13 changes: 13 additions & 0 deletions torch/_inductor/autoheuristic/learned_heuristic_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,16 @@ def get_decision(self) -> Optional[Choice]:
if heuristic.check_precondition(self.metadata, self.context):
return heuristic.get_decision(self.context, self.metadata.choices)
return None

def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
heuristics = self.get_heuristics(self.metadata.name)
for heuristic in heuristics:
if heuristic.check_precondition(self.metadata, self.context):
choices = heuristic.get_decisions_ranked(self.context)
if choices is None:
return None
avail_choices = [
choice for choice in choices if choice in self.metadata.choices
]
return avail_choices[:top_k]
return None
13 changes: 13 additions & 0 deletions torch/_inductor/autoheuristic/learnedheuristic_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def get_confidence_threshold(self) -> float:
def get_name(self) -> str:
return ""

def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
return None


class LearnedHeuristicRegression(LearnedHeuristic):
def __init__(self) -> None:
Expand Down Expand Up @@ -75,5 +78,15 @@ def get_decision(
return None
return self.get_choice(best_choice_idx)

def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
feedback_idx_list = self.get_best_choices(context)
if feedback_idx_list is None:
return None
choices = [
self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list
]
choices = [choice for choice in choices if choice is not None]
return choices

def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
return []
32 changes: 27 additions & 5 deletions torch/_inductor/kernel/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,23 @@ def tuned_mm(mat1, mat2, *, layout=None):
layout=layout,
**mm_options(config, m, n, k, layout),
)
choice = mm_autoheuristic(
mat1, mat2, m, n, k, choices, name, input_nodes, mm_operations(), None
ah_choices = mm_autoheuristic(
mat1,
mat2,
m,
n,
k,
choices,
name,
input_nodes,
mm_operations(),
None,
top_k=10,
)
if not torch._inductor.config.collect_autoheuristic(name):
# if we are collecting data, we do not want to modify choices
if choice is not None:
choices.insert(0, choice)
if ah_choices is not None and len(ah_choices) > 0:
choices = ah_choices
else:
choices = choices[:num_choices_before_extra_configs]

Expand Down Expand Up @@ -505,7 +515,17 @@ def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):


def mm_autoheuristic(
mat1, mat2, m, n, k, choices, name, input_nodes, ops, precondition
mat1,
mat2,
m,
n,
k,
choices,
name,
input_nodes,
ops,
precondition,
top_k: Optional[int] = None,
):
m, n, k = get_size_hints(mat1, mat2, m, n, k)
if not dims_are_int([m, n, k]):
Expand Down Expand Up @@ -544,6 +564,8 @@ def fallback():
augment_context=ops,
precondition=precondition,
)
if top_k is not None:
return autoheuristic.get_top_k_choices_caller(top_k)
return autoheuristic.get_choice_caller()


Expand Down

0 comments on commit 3a904d1

Please sign in to comment.