@@ -83,7 +83,7 @@ def forward(self, x):
83
83
84
84
x = torch .ones (3 , 3 )
85
85
f = io .BytesIO ()
86
- torch .onnx .export (AddmmModel (), x , f )
86
+ torch .onnx .export (AddmmModel (), x , f , verbose = False )
87
87
88
88
def test_onnx_transpose_incomplete_tensor_type (self ):
89
89
# Smoke test to get us into the state where we are attempting to export
@@ -115,8 +115,7 @@ def foo(x):
115
115
116
116
traced = torch .jit .trace (foo , (torch .rand ([2 ])))
117
117
118
- f = io .BytesIO ()
119
- torch .onnx .export (traced , (torch .rand ([2 ]),), f )
118
+ torch .onnx .export_to_pretty_string (traced , (torch .rand ([2 ]),))
120
119
121
120
def test_onnx_export_script_module (self ):
122
121
class ModuleToExport (torch .jit .ScriptModule ):
@@ -126,8 +125,7 @@ def forward(self, x):
126
125
return x + x
127
126
128
127
mte = ModuleToExport ()
129
- f = io .BytesIO ()
130
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
128
+ torch .onnx .export_to_pretty_string (mte , (torch .zeros (1 , 2 , 3 ),), verbose = False )
131
129
132
130
@common_utils .suppress_warnings
133
131
def test_onnx_export_func_with_warnings (self ):
@@ -140,8 +138,9 @@ def forward(self, x):
140
138
return func_with_warning (x )
141
139
142
140
# no exception
143
- f = io .BytesIO ()
144
- torch .onnx .export (WarningTest (), torch .randn (42 ), f )
141
+ torch .onnx .export_to_pretty_string (
142
+ WarningTest (), torch .randn (42 ), verbose = False
143
+ )
145
144
146
145
def test_onnx_export_script_python_fail (self ):
147
146
class PythonModule (torch .jit .ScriptModule ):
@@ -162,7 +161,7 @@ def forward(self, x):
162
161
mte = ModuleToExport ()
163
162
f = io .BytesIO ()
164
163
with self .assertRaisesRegex (RuntimeError , "Couldn't export Python" ):
165
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
164
+ torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f , verbose = False )
166
165
167
166
def test_onnx_export_script_inline_trace (self ):
168
167
class ModuleToInline (torch .nn .Module ):
@@ -180,8 +179,7 @@ def forward(self, x):
180
179
return y + y
181
180
182
181
mte = ModuleToExport ()
183
- f = io .BytesIO ()
184
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
182
+ torch .onnx .export_to_pretty_string (mte , (torch .zeros (1 , 2 , 3 ),), verbose = False )
185
183
186
184
def test_onnx_export_script_inline_script (self ):
187
185
class ModuleToInline (torch .jit .ScriptModule ):
@@ -200,8 +198,7 @@ def forward(self, x):
200
198
return y + y
201
199
202
200
mte = ModuleToExport ()
203
- f = io .BytesIO ()
204
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
201
+ torch .onnx .export_to_pretty_string (mte , (torch .zeros (1 , 2 , 3 ),), verbose = False )
205
202
206
203
def test_onnx_export_script_module_loop (self ):
207
204
class ModuleToExport (torch .jit .ScriptModule ):
@@ -215,8 +212,7 @@ def forward(self, x):
215
212
return x
216
213
217
214
mte = ModuleToExport ()
218
- f = io .BytesIO ()
219
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
215
+ torch .onnx .export_to_pretty_string (mte , (torch .zeros (1 , 2 , 3 ),), verbose = False )
220
216
221
217
@common_utils .suppress_warnings
222
218
def test_onnx_export_script_truediv (self ):
@@ -228,8 +224,9 @@ def forward(self, x):
228
224
229
225
mte = ModuleToExport ()
230
226
231
- f = io .BytesIO ()
232
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 , dtype = torch .float ),), f )
227
+ torch .onnx .export_to_pretty_string (
228
+ mte , (torch .zeros (1 , 2 , 3 , dtype = torch .float ),), verbose = False
229
+ )
233
230
234
231
def test_onnx_export_script_non_alpha_add_sub (self ):
235
232
class ModuleToExport (torch .jit .ScriptModule ):
@@ -239,8 +236,7 @@ def forward(self, x):
239
236
return bs - 1
240
237
241
238
mte = ModuleToExport ()
242
- f = io .BytesIO ()
243
- torch .onnx .export (mte , (torch .rand (3 , 4 ),), f )
239
+ torch .onnx .export_to_pretty_string (mte , (torch .rand (3 , 4 ),), verbose = False )
244
240
245
241
def test_onnx_export_script_module_if (self ):
246
242
class ModuleToExport (torch .jit .ScriptModule ):
@@ -251,8 +247,7 @@ def forward(self, x):
251
247
return x
252
248
253
249
mte = ModuleToExport ()
254
- f = io .BytesIO ()
255
- torch .onnx .export (mte , (torch .zeros (1 , 2 , 3 ),), f )
250
+ torch .onnx .export_to_pretty_string (mte , (torch .zeros (1 , 2 , 3 ),), verbose = False )
256
251
257
252
def test_onnx_export_script_inline_params (self ):
258
253
class ModuleToInline (torch .jit .ScriptModule ):
@@ -282,8 +277,7 @@ def forward(self, x):
282
277
torch .mm (torch .zeros (2 , 3 ), torch .ones (3 , 3 )), torch .ones (3 , 4 )
283
278
)
284
279
self .assertEqual (result , reference )
285
- f = io .BytesIO ()
286
- torch .onnx .export (mte , (torch .ones (2 , 3 ),), f )
280
+ torch .onnx .export_to_pretty_string (mte , (torch .ones (2 , 3 ),), verbose = False )
287
281
288
282
def test_onnx_export_speculate (self ):
289
283
class Foo (torch .jit .ScriptModule ):
@@ -318,10 +312,8 @@ def transpose(x):
318
312
f1 = Foo (transpose )
319
313
f2 = Foo (linear )
320
314
321
- f = io .BytesIO ()
322
- torch .onnx .export (f1 , (torch .ones (1 , 10 , dtype = torch .float ),), f )
323
- f = io .BytesIO ()
324
- torch .onnx .export (f2 , (torch .ones (1 , 10 , dtype = torch .float ),), f )
315
+ torch .onnx .export_to_pretty_string (f1 , (torch .ones (1 , 10 , dtype = torch .float ),))
316
+ torch .onnx .export_to_pretty_string (f2 , (torch .ones (1 , 10 , dtype = torch .float ),))
325
317
326
318
def test_onnx_export_shape_reshape (self ):
327
319
class Foo (torch .nn .Module ):
@@ -334,20 +326,17 @@ def forward(self, x):
334
326
return reshaped
335
327
336
328
foo = torch .jit .trace (Foo (), torch .zeros (1 , 2 , 3 ))
337
- f = io .BytesIO ()
338
- torch .onnx .export (foo , (torch .zeros (1 , 2 , 3 )), f )
329
+ torch .onnx .export_to_pretty_string (foo , (torch .zeros (1 , 2 , 3 )))
339
330
340
331
def test_listconstruct_erasure (self ):
341
332
class FooMod (torch .nn .Module ):
342
333
def forward (self , x ):
343
334
mask = x < 0.0
344
335
return x [mask ]
345
336
346
- f = io .BytesIO ()
347
- torch .onnx .export (
337
+ torch .onnx .export_to_pretty_string (
348
338
FooMod (),
349
339
(torch .rand (3 , 4 ),),
350
- f ,
351
340
add_node_names = False ,
352
341
do_constant_folding = False ,
353
342
operator_export_type = torch .onnx .OperatorExportTypes .ONNX_ATEN_FALLBACK ,
@@ -362,10 +351,13 @@ def forward(self, x):
362
351
retval += torch .sum (x [0 :i ], dim = 0 )
363
352
return retval
364
353
354
+ mod = DynamicSliceExportMod ()
355
+
365
356
input = torch .rand (3 , 4 , 5 )
366
357
367
- f = io .BytesIO ()
368
- torch .onnx .export (DynamicSliceExportMod (), (input ,), f , opset_version = 10 )
358
+ torch .onnx .export_to_pretty_string (
359
+ DynamicSliceExportMod (), (input ,), opset_version = 10
360
+ )
369
361
370
362
def test_export_dict (self ):
371
363
class DictModule (torch .nn .Module ):
@@ -376,12 +368,10 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
376
368
mod = DictModule ()
377
369
mod .train (False )
378
370
379
- f = io .BytesIO ()
380
- torch .onnx .export (mod , (x_in ,), f )
371
+ torch .onnx .export_to_pretty_string (mod , (x_in ,))
381
372
382
373
with self .assertRaisesRegex (RuntimeError , r"DictConstruct.+is not supported." ):
383
- f = io .BytesIO ()
384
- torch .onnx .export (torch .jit .script (mod ), (x_in ,), f )
374
+ torch .onnx .export_to_pretty_string (torch .jit .script (mod ), (x_in ,))
385
375
386
376
def test_source_range_propagation (self ):
387
377
class ExpandingModule (torch .nn .Module ):
@@ -507,11 +497,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]):
507
497
proposal = [torch .randn (2 , 4 ), torch .randn (2 , 4 )]
508
498
509
499
with self .assertRaises (RuntimeError ) as cm :
510
- f = io .BytesIO ()
500
+ onnx_model = io .BytesIO ()
511
501
torch .onnx .export (
512
502
model ,
513
503
(box_regression , proposal ),
514
- f ,
504
+ onnx_model ,
515
505
)
516
506
517
507
def test_initializer_sequence (self ):
@@ -647,7 +637,7 @@ def forward(self, x):
647
637
648
638
x = torch .randn (1 , 2 , 3 , requires_grad = True )
649
639
f = io .BytesIO ()
650
- torch .onnx .export (Model (), ( x ,) , f )
640
+ torch .onnx .export (Model (), x , f )
651
641
model = onnx .load (f )
652
642
model .ir_version = 0
653
643
@@ -754,7 +744,7 @@ def forward(self, x):
754
744
755
745
f = io .BytesIO ()
756
746
with warnings .catch_warnings (record = True ):
757
- torch .onnx .export (MyDrop (), (eg ,), f )
747
+ torch .onnx .export (MyDrop (), (eg ,), f , verbose = False )
758
748
759
749
def test_pack_padded_pad_packed_trace (self ):
760
750
from torch .nn .utils .rnn import pack_padded_sequence , pad_packed_sequence
@@ -801,7 +791,7 @@ def forward(self, x, seq_lens):
801
791
self .assertEqual (grad , grad_traced )
802
792
803
793
f = io .BytesIO ()
804
- torch .onnx .export (m , (x , seq_lens ), f )
794
+ torch .onnx .export (m , (x , seq_lens ), f , verbose = False )
805
795
806
796
# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
807
797
@common_utils .suppress_warnings
@@ -861,7 +851,7 @@ def forward(self, x, seq_lens):
861
851
self .assertEqual (grad , grad_traced )
862
852
863
853
f = io .BytesIO ()
864
- torch .onnx .export (m , (x , seq_lens ), f )
854
+ torch .onnx .export (m , (x , seq_lens ), f , verbose = False )
865
855
866
856
def test_pushpackingpastrnn_in_peephole_create_own_gather_input (self ):
867
857
from torch .nn .utils .rnn import pack_padded_sequence , pad_packed_sequence
@@ -941,8 +931,7 @@ class Mod(torch.nn.Module):
941
931
def forward (self , x , w ):
942
932
return torch .matmul (x , w ).detach ()
943
933
944
- f = io .BytesIO ()
945
- torch .onnx .export (Mod (), (torch .rand (3 , 4 ), torch .rand (4 , 5 )), f )
934
+ torch .onnx .export_to_pretty_string (Mod (), (torch .rand (3 , 4 ), torch .rand (4 , 5 )))
946
935
947
936
def test_aten_fallback_must_fallback (self ):
948
937
class ModelWithAtenNotONNXOp (torch .nn .Module ):
@@ -1099,12 +1088,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size):
1099
1088
torch .onnx .register_custom_op_symbolic (
1100
1089
"torch_scatter::scatter_max" , sym_scatter_max , 1
1101
1090
)
1102
- f = io .BytesIO ()
1103
1091
with torch .no_grad ():
1104
1092
torch .onnx .export (
1105
1093
m ,
1106
1094
(src , idx ),
1107
- f ,
1095
+ "mymodel.onnx" ,
1096
+ verbose = False ,
1108
1097
opset_version = 13 ,
1109
1098
custom_opsets = {"torch_scatter" : 1 },
1110
1099
do_constant_folding = True ,
@@ -1187,7 +1176,7 @@ def forward(self, x):
1187
1176
model = Net (C ).cuda ().half ()
1188
1177
x = torch .randn (N , C ).cuda ().half ()
1189
1178
f = io .BytesIO ()
1190
- torch .onnx .export (model , ( x ,) , f , opset_version = 14 )
1179
+ torch .onnx .export (model , x , f , opset_version = 14 )
1191
1180
onnx_model = onnx .load_from_string (f .getvalue ())
1192
1181
const_node = [n for n in onnx_model .graph .node if n .op_type == "Constant" ]
1193
1182
self .assertNotEqual (len (const_node ), 0 )
0 commit comments