Skip to content

Commit 2831af3

Browse files
Revert "[ONNX] Remove deprecated export_to_pretty_string (pytorch#137790)"
This reverts commit d0628a7. Reverted pytorch#137790 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#137789 (comment)))
1 parent dac0b4e commit 2831af3

File tree

4 files changed

+119
-48
lines changed

4 files changed

+119
-48
lines changed

docs/source/onnx_torchscript.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ Functions
697697
^^^^^^^^^
698698

699699
.. autofunction:: export
700+
.. autofunction:: export_to_pretty_string
700701
.. autofunction:: register_custom_op_symbolic
701702
.. autofunction:: unregister_custom_op_symbolic
702703
.. autofunction:: select_model_mode_for_export

test/onnx/test_pytorch_onnx_no_runtime.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def forward(self, x):
8383

8484
x = torch.ones(3, 3)
8585
f = io.BytesIO()
86-
torch.onnx.export(AddmmModel(), x, f)
86+
torch.onnx.export(AddmmModel(), x, f, verbose=False)
8787

8888
def test_onnx_transpose_incomplete_tensor_type(self):
8989
# Smoke test to get us into the state where we are attempting to export
@@ -115,8 +115,7 @@ def foo(x):
115115

116116
traced = torch.jit.trace(foo, (torch.rand([2])))
117117

118-
f = io.BytesIO()
119-
torch.onnx.export(traced, (torch.rand([2]),), f)
118+
torch.onnx.export_to_pretty_string(traced, (torch.rand([2]),))
120119

121120
def test_onnx_export_script_module(self):
122121
class ModuleToExport(torch.jit.ScriptModule):
@@ -126,8 +125,7 @@ def forward(self, x):
126125
return x + x
127126

128127
mte = ModuleToExport()
129-
f = io.BytesIO()
130-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
128+
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
131129

132130
@common_utils.suppress_warnings
133131
def test_onnx_export_func_with_warnings(self):
@@ -140,8 +138,9 @@ def forward(self, x):
140138
return func_with_warning(x)
141139

142140
# no exception
143-
f = io.BytesIO()
144-
torch.onnx.export(WarningTest(), torch.randn(42), f)
141+
torch.onnx.export_to_pretty_string(
142+
WarningTest(), torch.randn(42), verbose=False
143+
)
145144

146145
def test_onnx_export_script_python_fail(self):
147146
class PythonModule(torch.jit.ScriptModule):
@@ -162,7 +161,7 @@ def forward(self, x):
162161
mte = ModuleToExport()
163162
f = io.BytesIO()
164163
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
165-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
164+
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
166165

167166
def test_onnx_export_script_inline_trace(self):
168167
class ModuleToInline(torch.nn.Module):
@@ -180,8 +179,7 @@ def forward(self, x):
180179
return y + y
181180

182181
mte = ModuleToExport()
183-
f = io.BytesIO()
184-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
182+
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
185183

186184
def test_onnx_export_script_inline_script(self):
187185
class ModuleToInline(torch.jit.ScriptModule):
@@ -200,8 +198,7 @@ def forward(self, x):
200198
return y + y
201199

202200
mte = ModuleToExport()
203-
f = io.BytesIO()
204-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
201+
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
205202

206203
def test_onnx_export_script_module_loop(self):
207204
class ModuleToExport(torch.jit.ScriptModule):
@@ -215,8 +212,7 @@ def forward(self, x):
215212
return x
216213

217214
mte = ModuleToExport()
218-
f = io.BytesIO()
219-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
215+
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
220216

221217
@common_utils.suppress_warnings
222218
def test_onnx_export_script_truediv(self):
@@ -228,8 +224,9 @@ def forward(self, x):
228224

229225
mte = ModuleToExport()
230226

231-
f = io.BytesIO()
232-
torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f)
227+
torch.onnx.export_to_pretty_string(
228+
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), verbose=False
229+
)
233230

234231
def test_onnx_export_script_non_alpha_add_sub(self):
235232
class ModuleToExport(torch.jit.ScriptModule):
@@ -239,8 +236,7 @@ def forward(self, x):
239236
return bs - 1
240237

241238
mte = ModuleToExport()
242-
f = io.BytesIO()
243-
torch.onnx.export(mte, (torch.rand(3, 4),), f)
239+
torch.onnx.export_to_pretty_string(mte, (torch.rand(3, 4),), verbose=False)
244240

245241
def test_onnx_export_script_module_if(self):
246242
class ModuleToExport(torch.jit.ScriptModule):
@@ -251,8 +247,7 @@ def forward(self, x):
251247
return x
252248

253249
mte = ModuleToExport()
254-
f = io.BytesIO()
255-
torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
250+
torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
256251

257252
def test_onnx_export_script_inline_params(self):
258253
class ModuleToInline(torch.jit.ScriptModule):
@@ -282,8 +277,7 @@ def forward(self, x):
282277
torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)
283278
)
284279
self.assertEqual(result, reference)
285-
f = io.BytesIO()
286-
torch.onnx.export(mte, (torch.ones(2, 3),), f)
280+
torch.onnx.export_to_pretty_string(mte, (torch.ones(2, 3),), verbose=False)
287281

288282
def test_onnx_export_speculate(self):
289283
class Foo(torch.jit.ScriptModule):
@@ -318,10 +312,8 @@ def transpose(x):
318312
f1 = Foo(transpose)
319313
f2 = Foo(linear)
320314

321-
f = io.BytesIO()
322-
torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f)
323-
f = io.BytesIO()
324-
torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f)
315+
torch.onnx.export_to_pretty_string(f1, (torch.ones(1, 10, dtype=torch.float),))
316+
torch.onnx.export_to_pretty_string(f2, (torch.ones(1, 10, dtype=torch.float),))
325317

326318
def test_onnx_export_shape_reshape(self):
327319
class Foo(torch.nn.Module):
@@ -334,20 +326,17 @@ def forward(self, x):
334326
return reshaped
335327

336328
foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
337-
f = io.BytesIO()
338-
torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f)
329+
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)))
339330

340331
def test_listconstruct_erasure(self):
341332
class FooMod(torch.nn.Module):
342333
def forward(self, x):
343334
mask = x < 0.0
344335
return x[mask]
345336

346-
f = io.BytesIO()
347-
torch.onnx.export(
337+
torch.onnx.export_to_pretty_string(
348338
FooMod(),
349339
(torch.rand(3, 4),),
350-
f,
351340
add_node_names=False,
352341
do_constant_folding=False,
353342
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
@@ -362,10 +351,13 @@ def forward(self, x):
362351
retval += torch.sum(x[0:i], dim=0)
363352
return retval
364353

354+
mod = DynamicSliceExportMod()
355+
365356
input = torch.rand(3, 4, 5)
366357

367-
f = io.BytesIO()
368-
torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10)
358+
torch.onnx.export_to_pretty_string(
359+
DynamicSliceExportMod(), (input,), opset_version=10
360+
)
369361

370362
def test_export_dict(self):
371363
class DictModule(torch.nn.Module):
@@ -376,12 +368,10 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
376368
mod = DictModule()
377369
mod.train(False)
378370

379-
f = io.BytesIO()
380-
torch.onnx.export(mod, (x_in,), f)
371+
torch.onnx.export_to_pretty_string(mod, (x_in,))
381372

382373
with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
383-
f = io.BytesIO()
384-
torch.onnx.export(torch.jit.script(mod), (x_in,), f)
374+
torch.onnx.export_to_pretty_string(torch.jit.script(mod), (x_in,))
385375

386376
def test_source_range_propagation(self):
387377
class ExpandingModule(torch.nn.Module):
@@ -507,11 +497,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]):
507497
proposal = [torch.randn(2, 4), torch.randn(2, 4)]
508498

509499
with self.assertRaises(RuntimeError) as cm:
510-
f = io.BytesIO()
500+
onnx_model = io.BytesIO()
511501
torch.onnx.export(
512502
model,
513503
(box_regression, proposal),
514-
f,
504+
onnx_model,
515505
)
516506

517507
def test_initializer_sequence(self):
@@ -647,7 +637,7 @@ def forward(self, x):
647637

648638
x = torch.randn(1, 2, 3, requires_grad=True)
649639
f = io.BytesIO()
650-
torch.onnx.export(Model(), (x,), f)
640+
torch.onnx.export(Model(), x, f)
651641
model = onnx.load(f)
652642
model.ir_version = 0
653643

@@ -754,7 +744,7 @@ def forward(self, x):
754744

755745
f = io.BytesIO()
756746
with warnings.catch_warnings(record=True):
757-
torch.onnx.export(MyDrop(), (eg,), f)
747+
torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
758748

759749
def test_pack_padded_pad_packed_trace(self):
760750
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -801,7 +791,7 @@ def forward(self, x, seq_lens):
801791
self.assertEqual(grad, grad_traced)
802792

803793
f = io.BytesIO()
804-
torch.onnx.export(m, (x, seq_lens), f)
794+
torch.onnx.export(m, (x, seq_lens), f, verbose=False)
805795

806796
# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
807797
@common_utils.suppress_warnings
@@ -861,7 +851,7 @@ def forward(self, x, seq_lens):
861851
self.assertEqual(grad, grad_traced)
862852

863853
f = io.BytesIO()
864-
torch.onnx.export(m, (x, seq_lens), f)
854+
torch.onnx.export(m, (x, seq_lens), f, verbose=False)
865855

866856
def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
867857
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -941,8 +931,7 @@ class Mod(torch.nn.Module):
941931
def forward(self, x, w):
942932
return torch.matmul(x, w).detach()
943933

944-
f = io.BytesIO()
945-
torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
934+
torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
946935

947936
def test_aten_fallback_must_fallback(self):
948937
class ModelWithAtenNotONNXOp(torch.nn.Module):
@@ -1099,12 +1088,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size):
10991088
torch.onnx.register_custom_op_symbolic(
11001089
"torch_scatter::scatter_max", sym_scatter_max, 1
11011090
)
1102-
f = io.BytesIO()
11031091
with torch.no_grad():
11041092
torch.onnx.export(
11051093
m,
11061094
(src, idx),
1107-
f,
1095+
"mymodel.onnx",
1096+
verbose=False,
11081097
opset_version=13,
11091098
custom_opsets={"torch_scatter": 1},
11101099
do_constant_folding=True,
@@ -1187,7 +1176,7 @@ def forward(self, x):
11871176
model = Net(C).cuda().half()
11881177
x = torch.randn(N, C).cuda().half()
11891178
f = io.BytesIO()
1190-
torch.onnx.export(model, (x,), f, opset_version=14)
1179+
torch.onnx.export(model, x, f, opset_version=14)
11911180
onnx_model = onnx.load_from_string(f.getvalue())
11921181
const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"]
11931182
self.assertNotEqual(len(const_node), 0)

torch/onnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"JitScalarType",
3131
# Public functions
3232
"export",
33+
"export_to_pretty_string",
3334
"is_in_onnx_export",
3435
"select_model_mode_for_export",
3536
"register_custom_op_symbolic",
@@ -67,6 +68,7 @@
6768
from .utils import (
6869
_run_symbolic_function,
6970
_run_symbolic_method,
71+
export_to_pretty_string,
7072
is_in_onnx_export,
7173
register_custom_op_symbolic,
7274
select_model_mode_for_export,

torch/onnx/utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"model_signature",
3636
"warn_on_static_input_change",
3737
"unpack_quantized_tensor",
38+
"export_to_pretty_string",
3839
"unconvertible_ops",
3940
"register_custom_op_symbolic",
4041
"unregister_custom_op_symbolic",
@@ -1139,6 +1140,84 @@ def _model_to_graph(
11391140
return graph, params_dict, torch_out
11401141

11411142

1143+
@torch._disable_dynamo
1144+
@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
1145+
def export_to_pretty_string(
1146+
model,
1147+
args,
1148+
export_params=True,
1149+
verbose=False,
1150+
training=_C_onnx.TrainingMode.EVAL,
1151+
input_names=None,
1152+
output_names=None,
1153+
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
1154+
export_type=None,
1155+
google_printer=False,
1156+
opset_version=None,
1157+
keep_initializers_as_inputs=None,
1158+
custom_opsets=None,
1159+
add_node_names=True,
1160+
do_constant_folding=True,
1161+
dynamic_axes=None,
1162+
):
1163+
"""Similar to :func:`export`, but returns a text representation of the ONNX model.
1164+
1165+
Only differences in args listed below. All other args are the same
1166+
as :func:`export`.
1167+
1168+
Args:
1169+
add_node_names (bool, default True): Whether or not to set
1170+
NodeProto.name. This makes no difference unless
1171+
``google_printer=True``.
1172+
google_printer (bool, default False): If False, will return a custom,
1173+
compact representation of the model. If True will return the
1174+
protobuf's `Message::DebugString()`, which is more verbose.
1175+
1176+
Returns:
1177+
A UTF-8 str containing a human-readable representation of the ONNX model.
1178+
"""
1179+
if opset_version is None:
1180+
opset_version = _constants.ONNX_DEFAULT_OPSET
1181+
if custom_opsets is None:
1182+
custom_opsets = {}
1183+
GLOBALS.export_onnx_opset_version = opset_version
1184+
GLOBALS.operator_export_type = operator_export_type
1185+
1186+
with exporter_context(model, training, verbose):
1187+
val_keep_init_as_ip = _decide_keep_init_as_input(
1188+
keep_initializers_as_inputs, operator_export_type, opset_version
1189+
)
1190+
val_add_node_names = _decide_add_node_names(
1191+
add_node_names, operator_export_type
1192+
)
1193+
val_do_constant_folding = _decide_constant_folding(
1194+
do_constant_folding, operator_export_type, training
1195+
)
1196+
args = _decide_input_format(model, args)
1197+
graph, params_dict, torch_out = _model_to_graph(
1198+
model,
1199+
args,
1200+
verbose,
1201+
input_names,
1202+
output_names,
1203+
operator_export_type,
1204+
val_do_constant_folding,
1205+
training=training,
1206+
dynamic_axes=dynamic_axes,
1207+
)
1208+
1209+
return graph._pretty_print_onnx( # type: ignore[attr-defined]
1210+
params_dict,
1211+
opset_version,
1212+
False,
1213+
operator_export_type,
1214+
google_printer,
1215+
val_keep_init_as_ip,
1216+
custom_opsets,
1217+
val_add_node_names,
1218+
)
1219+
1220+
11421221
@_deprecation.deprecated("2.5", "the future", "avoid using this function")
11431222
def unconvertible_ops(
11441223
model,

0 commit comments

Comments
 (0)