 import torch.utils._pytree as pytree
 from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
+from torch._dynamo.testing import normalize_gm
 from torch._higher_order_ops.associative_scan import (
     _fake_associative_scan,
     associative_scan,
@@ -6181,6 +6182,19 @@ def test_while_loop_schema_gen(self):
         )
         self.assertEqual(schema.parse(str(schema)), schema)
 
+    # Return the .module() graph string of the non-strict export.
+    def _check_export(self, fn, args, dynamic_shapes=None) -> str:
+        strict_ep = torch.export.export(
+            fn, args, dynamic_shapes=dynamic_shapes, strict=True
+        )
+        non_strict_ep = torch.export.export(
+            fn, args, dynamic_shapes=dynamic_shapes, strict=False
+        )
+        eager_res = fn(*args)
+        self.assertEqual(strict_ep.module()(*args), eager_res)
+        self.assertEqual(non_strict_ep.module()(*args), eager_res)
+        return normalize_gm(non_strict_ep.module().print_readable(print_output=False))
+
     @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_cond_eager_run_with_item(self):
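
Aside: the `_check_export` helper above asserts eager/strict/non-strict parity and returns the normalized non-strict graph. A minimal standalone sketch of the same check, outside the test class (the `Toy` module and `check_export` name are illustrative, not part of this PR):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

def check_export(fn, args, dynamic_shapes=None):
    # Export the same callable in strict (dynamo-traced) and
    # non-strict (python-level tracing) modes.
    strict_ep = torch.export.export(fn, args, dynamic_shapes=dynamic_shapes, strict=True)
    non_strict_ep = torch.export.export(fn, args, dynamic_shapes=dynamic_shapes, strict=False)
    eager = fn(*args)
    # Both exported modules should reproduce the eager result.
    assert torch.equal(strict_ep.module()(*args), eager)
    assert torch.equal(non_strict_ep.module()(*args), eager)
    return non_strict_ep.module().print_readable(print_output=False)

print(check_export(Toy(), (torch.ones(3),)))

Returning only the non-strict graph keeps a single expected-inline string per test while still exercising both export modes.
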
@@ -6204,20 +6218,122 @@ def false_fn(x):
         )
         model = M()
         ep = torch.export.export(model, args)
+        graph_str = self._check_export(model, args, None)
         self.assertExpectedInline(
-            ep.module().code.strip(),
+            graph_str,
             """\
-def forward(self, a, b1, b2, c):
-    a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
-    true_graph_0 = self.true_graph_0
-    false_graph_0 = self.false_graph_0
-    cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, [c, b1, b2]); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
-    getitem = cond[0]; cond = None
-    mul = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
-    return pytree.tree_unflatten((mul,), self._out_spec)""",  # noqa: B950
+class GraphModule(torch.nn.Module):
+    def forward(self, a, b1, b2, c):
+        a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
+
+        a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
+        true_graph_0 = self.true_graph_0
+        false_graph_0 = self.false_graph_0
+        cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, [c, b1, b2]); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
+        getitem: "f32[10]" = cond[0]; cond = None
+
+        mul: "f32[10]" = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
+        return pytree.tree_unflatten((mul,), self._out_spec)
+
+    class true_graph_0(torch.nn.Module):
+        def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
+            item: "Sym(u0)" = torch.ops.aten.item.default(b1); b1 = None
+
+            mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
+            return (mul,)
+
+    class false_graph_0(torch.nn.Module):
+        def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
+            item: "Sym(u1)" = torch.ops.aten.item.default(b2); b2 = None
+
+            mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
+            return (mul,)
+""",  # noqa: B950
+        )
+
+    @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
+    def test_cond_symint_closure(self):
+        from torch.export import Dim
+
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                a = y.shape[0]
+                b = z.shape[0]
+
+                def true_fn(x):
+                    return x + a
+
+                def false_fn(x):
+                    return x + b * z
+
+                # When exporting with non-strict: a and b are symints,
+                # so torch.compile needs to wrap and trace symint inputs.
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
+        args = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
+        model = M()
+        dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
+        non_strict_graph_str = self._check_export(model, args, dynamic_shapes)
+        self.assertExpectedInline(
+            non_strict_graph_str,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, x, y, z):
+        x: "f32[s0, 3]"; y: "f32[s1]"; z: "f32[s0, 3]";
+
+        x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
+        sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
+        sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(y, 0); y = None
+
+        gt: "Sym(s0 > 5)" = sym_size_int_3 > 5
+
+        true_graph_0 = self.true_graph_0
+        false_graph_0 = self.false_graph_0
+        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, sym_size_int_4, sym_size_int_3, z]); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_3 = z = None
+        getitem: "f32[s0, 3]" = cond[0]; cond = None
+        return pytree.tree_unflatten((getitem,), self._out_spec)
+
+    class true_graph_0(torch.nn.Module):
+        def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
+            add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
+            return (add,)
+
+    class false_graph_0(torch.nn.Module):
+        def forward(self, x: "f32[s0, 3]", sym_size_int_4: "Sym(s1)", sym_size_int_3: "Sym(s0)", z: "f32[s0, 3]"):
+            mul: "f32[s0, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_3); z = sym_size_int_3 = None
+
+            add: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
+            return (add,)
+""",  # noqa: B950
         )
-        expected_output = model(*args)
-        self.assertEqual(expected_output, x * 3 * 2)
+
+    # Unbacked symint inputs are created during non-strict export,
+    # which causes a graph break.
+    @unittest.expectedFailure
+    def test_cond_unbacked_symint_closure(self):
+        from torch.export import Dim
+
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                a = y.shape[0]
+                b = z.shape[0]
+                # c is an unbacked symint in non-strict export
+                c = y.sum().item()
+
+                def true_fn(x):
+                    return x + a + c
+
+                def false_fn(x):
+                    return x + b * z * c
+
+                # When exporting with non-strict: a and b are symints,
+                # so torch.compile needs to wrap and trace symint inputs.
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
+        args = (torch.ones(3, 3), torch.ones(5, dtype=torch.int32), torch.ones(3, 3))
+        model = M()
+        dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
+        _ = self._check_export(model, args, dynamic_shapes)
 
 
 instantiate_parametrized_tests(TestHopSchema)
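
For context, a minimal eager-mode sketch of the closed-over-symint pattern the new tests exercise (a standalone function rather than the test's nn.Module; the values mirror the test args):

import torch

def run(x, y, z):
    a = y.shape[0]  # a plain int in eager mode; a backed symint under dynamic-shape export
    b = z.shape[0]

    def true_fn(x):
        return x + a  # branch closes over a

    def false_fn(x):
        return x + b * z  # branch closes over b and z

    # cond lifts the closed-over values into explicit operands of the branch
    # subgraphs, which is what the expected graphs above check.
    return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

print(run(torch.ones(3, 3), torch.ones(5), torch.ones(3, 3)))  # x.shape[0] == 3, so false_fn runs
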