Skip to content

Commit

Permalink
[BE] Consistently use the sym_stride lowering, instead of short-circu…
Browse files Browse the repository at this point in the history
…iting before (pytorch#113071)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#113071
Approved by: https://github.com/voznesenskym
  • Loading branch information
ezyang authored and pytorchmergebot committed Nov 10, 2023
1 parent 958f755 commit 9752ef5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
10 changes: 2 additions & 8 deletions torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,15 +737,9 @@ def debug(msg):
debug("layout_constraints")
args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
result = self.call_function(n.target, args, kwargs)
elif n.target == torch.ops.aten.sym_stride.int:
debug("sym_stride")
# inductor graphs can occasionally return sizes/strides,
# e.g. if we need to save symints for the backward graph.
if isinstance(n.meta["val"], torch.SymInt):
result = n.meta["val"].node.expr
else:
result = super().run_node(n)
elif is_magic_method(n.target):
# TODO: this is sus, it probably should be handled in the
# lowerings themselves similarly to sym_size/sym-stride
debug("is_magic_method")
if isinstance(n.meta["val"], torch.SymInt):
result = n.meta["val"].node.expr
Expand Down
19 changes: 17 additions & 2 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5039,12 +5039,27 @@ def sym_constrain_range(a, min, max):

@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
return a.get_size()[dim]
val = V.graph.current_node.meta["val"]
# Note [Can val be an int?]
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# In principle, someone could construct an FX graph where
# a call to size/stride has a val that is a plain int (not
# SymInt). However, we will maintain the invariant that
# this is not possible: if you are constructing an FX graph
# where there is a call to size/stride that returns an
# int, but you KNOW that int must always be a constant,
# then you do not need trace that call at all (and just
# constant propagate the integer as is.)
assert isinstance(val, torch.SymInt)
return val.node.expr


@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
return a.get_stride()[dim]
val = V.graph.current_node.meta["val"]
# See Note [Can val be an int?]
assert isinstance(val, torch.SymInt)
return val.node.expr


@register_lowering(aten.sym_numel)
Expand Down

0 comments on commit 9752ef5

Please sign in to comment.