
Commit 6713b45

ydwu4 authored and pytorchmergebot committed
[hop][dynamo] support torch.SymInt inputs (pytorch#141524)
Fixes pytorch#141305.

```python
class M(torch.nn.Module):
    def forward(self, x, y, z):
        a = y.shape[0]
        b = z.shape[0]

        def true_fn(x):
            return x + a

        def false_fn(x):
            return x + b * z

        # When exporting with non-strict: a and b are symints,
        # so torch.compile need to wrap and trace symint inputs.
        return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
```

In non-strict export, when inputs are annotated with dynamic shapes, `a` and `b` in the example above are of type `torch.SymInt`, so `true_fn` and `false_fn` close over `torch.SymInt` values. The error was triggered because Dynamo did not handle SymInt inputs and fell back to a `UserDefinedObjectVariable`, which has no proxy. We add support by following how SymBool inputs were handled previously.

Pull Request resolved: pytorch#141524
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#141610, pytorch#142185
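For context, here is a minimal driver for the module above, a sketch mirroring the new test `test_cond_symint_closure` added in this commit (the inputs and `Dim` names come from that test):

```python
import torch
from torch.export import Dim, export

# Shapes and dynamic-shape spec taken from test_cond_symint_closure:
# under non-strict export, y.shape[0] and z.shape[0] become backed SymInts,
# which true_fn/false_fn capture as closures.
args = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}

# Before this commit, the non-strict export below failed while Dynamo traced
# the branch functions; with the fix it produces a working ExportedProgram.
ep = export(M(), args, dynamic_shapes=dynamic_shapes, strict=False)
print(ep.module()(*args))
```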
1 parent 7eda06b commit 6713b45

File tree: 2 files changed, +160 −31 lines


test/functorch/test_control_flow.py

Lines changed: 127 additions & 11 deletions

```diff
@@ -7,6 +7,7 @@
 import torch.utils._pytree as pytree
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
+from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.associative_scan import (
     _fake_associative_scan,
     associative_scan,
@@ -6181,6 +6182,19 @@ def test_while_loop_schema_gen(self):
         )
         self.assertEqual(schema.parse(str(schema)), schema)
 
+    # Return the .module() graph str result of non-strict export
+    def _check_export(self, fn, args, dynamic_shapes=None) -> str:
+        strict_ep = torch.export.export(
+            fn, args, dynamic_shapes=dynamic_shapes, strict=True
+        )
+        non_strict_ep = torch.export.export(
+            fn, args, dynamic_shapes=dynamic_shapes, strict=False
+        )
+        eager_res = fn(*args)
+        self.assertEqual(strict_ep.module()(*args), eager_res)
+        self.assertEqual(non_strict_ep.module()(*args), eager_res)
+        return normalize_gm(non_strict_ep.module().print_readable(print_output=False))
+
     @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_cond_eager_run_with_item(self):
@@ -6204,20 +6218,122 @@ def false_fn(x):
         )
         model = M()
         ep = torch.export.export(model, args)
+        graph_str = self._check_export(model, args, None)
         self.assertExpectedInline(
-            ep.module().code.strip(),
+            graph_str,
             """\
-def forward(self, a, b1, b2, c):
-    a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
-    true_graph_0 = self.true_graph_0
-    false_graph_0 = self.false_graph_0
-    cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, [c, b1, b2]); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
-    getitem = cond[0]; cond = None
-    mul = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
-    return pytree.tree_unflatten((mul,), self._out_spec)""",  # noqa: B950
+class GraphModule(torch.nn.Module):
+    def forward(self, a, b1, b2, c):
+        a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
+
+        a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
+        true_graph_0 = self.true_graph_0
+        false_graph_0 = self.false_graph_0
+        cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, [c, b1, b2]); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
+        getitem: "f32[10]" = cond[0]; cond = None
+
+        mul: "f32[10]" = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
+        return pytree.tree_unflatten((mul,), self._out_spec)
+
+    class true_graph_0(torch.nn.Module):
+        def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
+            item: "Sym(u0)" = torch.ops.aten.item.default(b1); b1 = None
+
+            mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
+            return (mul,)
+
+    class false_graph_0(torch.nn.Module):
+        def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
+            item: "Sym(u1)" = torch.ops.aten.item.default(b2); b2 = None
+
+            mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
+            return (mul,)
+""",  # noqa: B950
+        )
+
+    @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
+    def test_cond_symint_closure(self):
+        from torch.export import Dim
+
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                a = y.shape[0]
+                b = z.shape[0]
+
+                def true_fn(x):
+                    return x + a
+
+                def false_fn(x):
+                    return x + b * z
+
+                # When exporting with non-strict: a and b are symints,
+                # so torch.compile need to wrap and trace symint inputs.
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
+        args = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
+        model = M()
+        dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
+        non_strict_graph_str = self._check_export(model, args, dynamic_shapes)
+        self.assertExpectedInline(
+            non_strict_graph_str,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, x, y, z):
+        x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]";
+
+        x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
+        sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
+        sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None
+
+        gt: "Sym(s0 > 5)" = sym_size_int_3 > 5
+
+        true_graph_0 = self.true_graph_0
+        false_graph_0 = self.false_graph_0
+        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, sym_size_int_4, sym_size_int_3, z]); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
+        getitem: "f32[s0, 3]" = cond[0]; cond = None
+        return pytree.tree_unflatten((getitem,), self._out_spec)
+
+    class true_graph_0(torch.nn.Module):
+        def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
+            add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
+            return (add,)
+
+    class false_graph_0(torch.nn.Module):
+        def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
+            mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
+
+            add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
+            return (add,)
+""",  # noqa: B950
         )
-        expected_output = model(*args)
-        self.assertEqual(expected_output, x * 3 * 2)
+
+    # unbacked symint inputs are created during non-strict export,
+    # which causes a graph break
+    @unittest.expectedFailure
+    def test_cond_unbacked_symint_closure(self):
+        from torch.export import Dim
+
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                a = y.shape[0]
+                b = z.shape[0]
+                # c is an unbacked symint in non-strict export
+                c = y.sum().item()
+
+                def true_fn(x):
+                    return x + a + c
+
+                def false_fn(x):
+                    return x + b * z * c
+
+                # When exporting with non-strict: a and b are symints,
+                # so torch.compile need to wrap and trace symint inputs.
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
+        args = (torch.ones(3, 3), torch.ones(5, dtype=torch.int32), torch.ones(3, 3))
+        model = M()
+        dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
+        _ = self._check_export(model, args, dynamic_shapes)
 
 
 instantiate_parametrized_tests(TestHopSchema)
```
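The `expectedFailure` test documents a remaining gap: under non-strict export, `y.sum().item()` produces an *unbacked* SymInt (a symbol with no hint), and wrapping such a closure still graph-breaks. A minimal sketch of where an unbacked SymInt comes from, assuming default `ShapeEnv` settings (illustrative, not code from this PR):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Non-strict export runs modules under FakeTensorMode: tensors carry shapes
# but no data, so a data-dependent .item() returns a fresh unbacked SymInt
# instead of a Python int.
with FakeTensorMode(shape_env=ShapeEnv()):
    y = torch.ones(5, dtype=torch.int32)
    c = y.sum().item()
    print(type(c))            # <class 'torch.SymInt'>
    print(c.node.has_hint())  # False -- the unbacked branch in builder.py below
```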

torch/_dynamo/variables/builder.py

Lines changed: 33 additions & 20 deletions

```diff
@@ -944,40 +944,53 @@ def build_key_value(i, k, v):
         ):
             self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return ItertoolsVariable(value, source=self.source)
-        elif isinstance(value, torch.SymBool):
-            # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
-            # user provided SymBool with a SymInt in dynamo.
+        elif isinstance(value, (torch.SymBool, torch.SymInt)) and not isinstance(
+            value.node, torch.nested._internal.nested_int.NestedIntNode
+        ):
+            # Note: this doesn't handle nested symints.
+            # For SymBool input, we re-use the infra for SymInt by simulating SymBool with a SymInt in dynamo.
 
             # Concretely,
             # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
             # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
             # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
             # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
-
-            new_source = ConvertIntSource(self.source)
+            source = (
+                self.source
+                if isinstance(value, torch.SymInt)
+                else ConvertIntSource(self.source)
+            )
             if value.node.has_hint():
-                value_hint = value.node.require_hint()
-
                 new_symint = (
                     self.tx.output.shape_env.create_unspecified_symint_and_symbol(
-                        int(value_hint),
-                        new_source,
+                        int(value.node.hint),
+                        source,
                         dynamic_dim=DimDynamic.DYNAMIC,
                     )
                 )
             else:
-                # We need to create an unbacked symint to replace the unbacked symbool.
-                new_symint = self.tx.output.shape_env.create_unbacked_symint()
+                if isinstance(value, torch.SymBool):
+                    # We need to create an unbacked symint to replace the unbacked symbool.
+                    new_symint = self.tx.output.shape_env.create_unbacked_symint()
+                else:
+                    # TODO (yidi): we need to figure out a way to propagate the guards
+                    # we accumulated when tracing the subggraph to outer shape_env. For normal symints,
+                    # this is automatically done by evaluating the guards once but this
+                    # will cause data-dependent error when we evaluate the outer unbacked symints.
+                    # The test case that triggers this graph break is test_cond_unbacked_symint_closure
+                    unimplemented(
+                        "unbacked symint input is not supported yet. If you need this feature, please file a github issue."
+                    )
 
             sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
                 re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                 type(new_symint),
                 new_symint,
-                source=new_source,
+                source=source,
             )
 
             sym_node_proxy.node.meta["grapharg"] = GraphArg(
-                new_source,
+                source,
                 new_symint,
                 False,
                 None,
@@ -989,13 +1002,13 @@ def build_key_value(i, k, v):
             assert isinstance(
                 sym_expr, sympy.Symbol
             ), f"{sym_expr} is not a basic Symbol."
-            self.tx.output.tracked_fakes.append(
-                TrackedFake(new_symint, new_source, None)
-            )
-            return SymNodeVariable(
-                sym_node_proxy,
-                new_symint == 1,
-            )
+            self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None))
+
+            tracing_symint = (
+                new_symint if isinstance(value, torch.SymInt) else new_symint == 1
+            )  # cast it back to symbool for tracing
+            return SymNodeVariable(sym_node_proxy, tracing_symint)
+
         elif isinstance(value, (JITFunction, Autotuner)):
             self.install_guards(GuardBuilder.ID_MATCH)
             return TritonKernelVariable(
```
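An illustrative sketch (not code from this PR) of the SymBool-via-SymInt simulation described in the comments above: the builder installs a SymInt graph input, and for a SymBool original value the new `tracing_symint` logic casts it back with `== 1`, while a SymInt original is passed through unchanged:

```python
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
# Stand-in for the SymInt that replaces an unbacked SymBool input...
s = shape_env.create_unbacked_symint()
# ...and the cast back to SymBool used for tracing the user's program.
b = s == 1
print(type(s).__name__, type(b).__name__)  # SymInt SymBool
```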
