Skip to content

Commit

Permalink
[inductor] Use symbolic_hint when bounding fallback size hint (pytorc…
Browse files Browse the repository at this point in the history
…h#127262)

The previous fallback ignores any known hint values in the expression and only
looks at the value ranges. By using the `symbolic_hint` we will use both hints
and value ranges.

Also removed the recursive use of `size_hint` on the bounds, since these should
always be constants.

Pull Request resolved: pytorch#127262
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#127251
  • Loading branch information
peterbell10 authored and pytorchmergebot committed May 28, 2024
1 parent 26a8fa3 commit f6ef832
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions torch/_inductor/sizevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,9 @@ def remove_precomputed_replacements(self, expr: Expr) -> Expr:
return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
return expr

def symbolic_hint(self, expr: Expr) -> Expr:
def symbolic_hint(self, expr: Expr) -> Union[Expr, int]:
# Substitute all hints into expr, but leave unbacked symints alone
expr = self.simplify(expr)
if not isinstance(expr, Expr):
assert isinstance(expr, int)
return expr
Expand All @@ -435,19 +436,20 @@ def symbolic_hint(self, expr: Expr) -> Expr:
return sympy_subs(expr, self.var_to_val)

def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
expr = self.simplify(expr)
out = self.symbolic_hint(expr)
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
# Use the provided heuristic fallback hint
sym_vrs = {
s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
unbacked_sym_vrs = {
s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
}
if all(vr is not None for vr in sym_vrs.values()):
expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type]
lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type]
upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type]
fallback = min(max(fallback, lower), upper)
if all(vr is not None for vr in unbacked_sym_vrs.values()):
hint_vr = bound_sympy(out, unbacked_sym_vrs) # type: ignore[arg-type]
if isinstance(hint_vr.lower, (int, sympy.Integer)):
fallback = max(fallback, int(hint_vr.lower))
if isinstance(hint_vr.upper, (int, sympy.Integer)):
fallback = min(fallback, int(hint_vr.upper))
return fallback

try:
return int(out)
except Exception:
Expand Down

0 comments on commit f6ef832

Please sign in to comment.