Skip to content

Commit

Permalink
[ONNX] Refactor exporter errors (pytorch#135180)
Browse files Browse the repository at this point in the history
Refactor exporter errors to combine old errors and new errors for API consistency.

This PR also

1. Removes the `_C._check_onnx_proto(proto)` call in the old exporter. We don't need the ONNX checker because it is limited.
2. Removes the `OnnxExporterError` defined in the dynamo module. This class unnecessarily stores the onnx program object, making it very bulky. Instead, we revert to use the plain OnnxExporterError defined in the `errors` module and use it as the base class for all errors.
3. Continues to expose `OnnxExporterError` in `torch.onnx` and the rest of the errors in `torch.onnx.errors`.
4. Removes the `CheckerError` and `InvalidExportOptionsError` from `torch.onnx`. This is BC breaking but should have low impact.
5. I did not rename existing errors out of compatibility considerations, even though `ExporterError` would have been more succinct.

Fixes pytorch#135125
Pull Request resolved: pytorch#135180
Approved by: https://github.com/titaiwangms
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 6, 2024
1 parent a15aabc commit 5eebd93
Show file tree
Hide file tree
Showing 14 changed files with 82 additions and 233 deletions.
5 changes: 4 additions & 1 deletion docs/source/onnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
onnx_dynamo_onnxruntime_backend
onnx_torchscript

.. TODO: Temporarily put the onnx errors module here. Update when we revamp the docs.
.. automodule:: torch.onnx.errors
:members:

.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.onnx.errors
.. py:module:: torch.onnx.operators
.. py:module:: torch.onnx.symbolic_caffe2
.. py:module:: torch.onnx.symbolic_helper
Expand Down
3 changes: 0 additions & 3 deletions docs/source/onnx_dynamo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,6 @@ API Reference
.. autoclass:: torch.onnx.ONNXRuntimeOptions
:members:

.. autoclass:: torch.onnx.InvalidExportOptionsError
:members:

.. autoclass:: torch.onnx.OnnxExporterError
:members:

Expand Down
41 changes: 0 additions & 41 deletions test/onnx/dynamo/test_exporter_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Owner(s): ["module: onnx"]
import io
import os

import onnx

import torch
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
from torch.onnx._internal import _exporter_legacy
from torch.onnx._internal._exporter_legacy import ResolvedExportOptions
from torch.testing._internal import common_utils

Expand Down Expand Up @@ -72,45 +70,6 @@ def test_save_to_existing_buffer_default_serializer(self):
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
onnx.load(buffer)

def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName(suffix=".sarif") as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
self.assertTrue(os.path.exists(path))

def test_save_sarif_log_to_file_with_failed_export(self):
class ModelWithExportError(torch.nn.Module):
def forward(self, x):
raise RuntimeError("Export error")

with self.assertRaises(RuntimeError):
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
self.assertTrue(
os.path.exists(_exporter_legacy._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH)
)

def test_onnx_program_accessible_from_exception_when_export_failed(self):
class ModelWithExportError(torch.nn.Module):
def forward(self, x):
raise RuntimeError("Export error")

with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError)
self.assertIsInstance(cm.exception.onnx_program, ONNXProgram)

def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export(
self,
):
class ModelWithExportError(torch.nn.Module):
def forward(self, x):
raise RuntimeError("Export error")

with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
onnx_program = cm.exception.onnx_program
with self.assertRaises(RuntimeError):
onnx_program.model_proto

def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
self,
):
Expand Down
37 changes: 10 additions & 27 deletions test/onnx/onnx_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import torch
from torch import export as torch_export
from torch.onnx import _constants, verification
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal import common_utils
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number
Expand Down Expand Up @@ -286,35 +285,19 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
# Feed args and kwargs into exporter.
# Note that exporter should flatten kwargs into positional args the exported model;
# since ONNX doesn't represent kwargs.
export_error: Optional[torch.onnx.OnnxExporterError] = None
try:
with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
onnx_program = torch.onnx.dynamo_export(
ref_model,
*ref_input_args,
**ref_input_kwargs,
export_options=torch.onnx.ExportOptions(
dynamic_shapes=self.dynamic_shapes,
diagnostic_options=torch.onnx.DiagnosticOptions(
verbosity_level=logging.DEBUG
),
with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
onnx_program = torch.onnx.dynamo_export(
ref_model,
*ref_input_args,
**ref_input_kwargs,
export_options=torch.onnx.ExportOptions(
dynamic_shapes=self.dynamic_shapes,
diagnostic_options=torch.onnx.DiagnosticOptions(
verbosity_level=logging.DEBUG
),
)
except torch.onnx.OnnxExporterError as e:
export_error = e
onnx_program = e.onnx_program

if diagnostics.is_onnx_diagnostics_log_artifact_enabled():
onnx_program.save_diagnostics(
f"test_report_{self._testMethodName}"
f"_dynamic_axes_{self.dynamic_shapes}"
f"_model_type_{self.model_type}"
".sarif"
),
)

if export_error is not None:
raise export_error

if not skip_dynamic_shapes_check:
assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

Expand Down
14 changes: 0 additions & 14 deletions test/onnx/test_fx_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,6 @@ def forward(self, input):
export_options=torch.onnx.ExportOptions(onnx_registry=registry),
)

try:
torch.onnx.dynamo_export(
TraceModel(),
x,
export_options=torch.onnx.ExportOptions(onnx_registry=registry),
)
except torch.onnx.OnnxExporterError as e:
assert_has_diagnostics(
e.onnx_program.diagnostic_context,
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
expected_node="aten.mul.Tensor",
)

def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info(
self,
):
Expand Down
27 changes: 0 additions & 27 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,33 +600,6 @@ def forward(self, x, y):
+ ")."
)

def test_onnx_checker_invalid_graph(self):
class CustomAddModule(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)

def symbolic_custom_invalid_add(g, input, other, alpha=None):
return g.op("Add", input, other, invalid_attr_i=1)

torch.onnx.register_custom_op_symbolic(
"::add", symbolic_custom_invalid_add, opset_version=9
)

x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)

test_model = CustomAddModule()
f = io.BytesIO()

try:
with self.assertRaises(torch.onnx.errors.CheckerError):
torch.onnx.export(test_model, (x, y), f, opset_version=9)
finally:
torch.onnx.unregister_custom_op_symbolic("::add", 9)

self.assertTrue(f.getvalue(), "ONNX graph was not exported.")
loaded_model = onnx.load_from_string(f.getvalue())

def test_shape_value_map(self):
class RSoftMax(torch.nn.Module):
def __init__(self, radix, cardinality):
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13582,7 +13582,7 @@ def forward(self, input, grid):
):
if self.opset_version < 20:
with self.assertRaises(
torch.onnx.errors.OnnxExporterError,
torch.onnx.OnnxExporterError,
):
self.run_test(
GridSampleModule(mode, padding_mode, align_corners),
Expand Down
23 changes: 9 additions & 14 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@
"unregister_custom_op_symbolic",
"disable_log",
"enable_log",
# Errors
"CheckerError", # Backwards compatibility
# Base error
"OnnxExporterError",
# Dynamo Exporter
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",
"dynamo_export",
"enable_fake_mode",
Expand All @@ -69,7 +67,7 @@
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import CheckerError # Backwards compatibility
from .errors import OnnxExporterError
from .utils import (
_optimize_graph,
_run_symbolic_function,
Expand Down Expand Up @@ -109,8 +107,6 @@
ExportOptions,
ONNXProgram,
ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
enable_fake_mode,
)
Expand All @@ -120,20 +116,19 @@
import os

# Set namespace for exposed private names
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
InvalidExportOptionsError.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
OnnxRegistry.__module__ = "torch.onnx"
DiagnosticOptions.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"
_OrtExecutionProvider.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
_OrtBackend.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
_OrtExecutionProvider.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"

producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION
Expand Down
60 changes: 17 additions & 43 deletions torch/onnx/_internal/_exporter_legacy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
# mypy: allow-untyped-defs
from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions)
annotations,
)
from __future__ import annotations


__all__ = [
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxRegistry",
"UnsatisfiedDependencyError",
"dynamo_export",
"enable_fake_mode",
]


import abc
import contextlib
Expand All @@ -17,6 +29,7 @@
import torch
import torch._ops
import torch.utils._pytree as pytree
from torch.onnx import errors
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
Expand Down Expand Up @@ -1061,28 +1074,6 @@ def __init__(self, package_name: str, message: str):
self.package_name = package_name


class OnnxExporterError(RuntimeError):
"""Raised when an ONNX exporter error occurs.
This exception is thrown when there's an error during the ONNX export process.
It encapsulates the :class:`ONNXProgram` object generated until the failure, allowing
access to the partial export results and associated metadata.
"""

onnx_program: Final[ONNXProgram] # type: ignore[misc]

def __init__(self, onnx_program: ONNXProgram, message: str):
"""
Initializes the OnnxExporterError with the given ONNX program and message.
Args:
onnx_program (ONNXProgram): The partial results of the ONNX export.
message (str): The error message to be displayed.
"""
super().__init__(message)
self.onnx_program = onnx_program


class InvalidExportOptionsError(RuntimeError):
"""Raised when user specified an invalid value for the :class:`ExportOptions`."""

Expand Down Expand Up @@ -1232,10 +1223,7 @@ def forward(self, x, bias=None):
"or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). "
f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
)
raise OnnxExporterError(
ONNXProgram._from_failure(e, resolved_export_options.diagnostic_context),
message,
) from e
raise errors.OnnxExporterError(message) from e


def common_pre_export_passes(
Expand Down Expand Up @@ -1313,17 +1301,3 @@ def common_pre_export_passes(
)

return module


__all__ = [
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
"OnnxRegistry",
"UnsatisfiedDependencyError",
"dynamo_export",
"enable_fake_mode",
]
3 changes: 2 additions & 1 deletion torch/onnx/_internal/exporter/_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from onnxscript.ir import convenience as ir_convenience

import torch
from torch.onnx._internal.exporter import _schemas, _tensors, errors
from torch.onnx import errors
from torch.onnx._internal.exporter import _schemas, _tensors


if TYPE_CHECKING:
Expand Down
Loading

0 comments on commit 5eebd93

Please sign in to comment.