
to_onnx return ONNXProgram #20811


Open: wants to merge 28 commits into master from feat/dynamo_export_onnx

Commits (28)
bde9614 feat: return `ONNXProgram` when exporting with dynamo=True. (GdoongMathew, May 11, 2025)
a966c0b test: add to_onnx(dynamo=True) unittests. (GdoongMathew, May 11, 2025)
e7342e3 fix: add ignore filter in pyproject.toml (GdoongMathew, May 11, 2025)
3ee3ea9 fix: change the return type annotation of `to_onnx`. (GdoongMathew, May 21, 2025)
bc81215 test: add parametrized `dynamo` to test `test_if_inference_output_is_… (GdoongMathew, May 21, 2025)
236f1a0 test: add difference check in `test_model_return_type`. (GdoongMathew, May 21, 2025)
019125d fix: fix unittest. (GdoongMathew, May 21, 2025)
9f5e604 Merge branch 'master' into feat/dynamo_export_onnx (GdoongMathew, May 30, 2025)
791d777 deps: bump typing_extension for onnxscript. (GdoongMathew, Jun 2, 2025)
453e63f Merge branch 'master' into feat/dynamo_export_onnx (GdoongMathew, Jun 2, 2025)
acdf3c1 deps: bump typing_extension for onnxscript. (GdoongMathew, Jun 2, 2025)
e046d27 deps: bump onnxscript upper bound. (GdoongMathew, Jun 3, 2025)
a0a7d1f test: add test `test_model_onnx_export_missing_onnxscript`. (GdoongMathew, Jun 5, 2025)
7aae865 Merge branch 'master' into feat/dynamo_export_onnx (GdoongMathew, Jun 6, 2025)
aca9fd1 revert typing-extension bump. (GdoongMathew, Jun 7, 2025)
1396f35 lower the min_torch version in unittest. (GdoongMathew, Jun 7, 2025)
8f050ea feat: enable ONNXProgram export on torch 2.5.0 (GdoongMathew, Jun 16, 2025)
3938c73 Merge branch 'master' into feat/dynamo_export_onnx (GdoongMathew, Jun 16, 2025)
ce3e6b7 extensions (Borda, Jun 16, 2025)
c31a3f6 Merge branch 'master' into feat/dynamo_export_onnx (Borda, Jun 16, 2025)
40b1449 Merge branch 'master' into feat/dynamo_export_onnx (Borda, Jun 16, 2025)
a470fe8 ds (Borda, Jun 18, 2025)
9e4a494 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 18, 2025)
b08e465 Merge branch 'master' into feat/dynamo_export_onnx (Borda, Jun 18, 2025)
67af423 dep: test fixing pydantic version. (GdoongMathew, Jun 18, 2025)
0e4cb80 Revert "dep: test fixing pydantic version." (GdoongMathew, Jun 18, 2025)
b26072d dep: add serve deps. (GdoongMathew, Jun 18, 2025)
d1b8597 ci: test. (GdoongMathew, Jun 18, 2025)
1 change: 1 addition & 0 deletions .actions/assistant.py
@@ -33,6 +33,7 @@
"requirements/pytorch/extra.txt",
"requirements/pytorch/strategies.txt",
"requirements/pytorch/examples.txt",
"requirements/pytorch/serve.txt",
),
"fabric": (
"requirements/fabric/base.txt",
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-fabric.yml
@@ -138,7 +138,7 @@ jobs:
- name: Install package & dependencies
timeout-minutes: 20
run: |
pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \
pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \
Review comment from the Contributor Author:
Not sure how this should be changed so that it could install the right extra packages ...

--extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}"
pip list
- name: Dump handy wheels
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-pytorch.yml
@@ -136,7 +136,7 @@ jobs:
- name: Install package & dependencies
timeout-minutes: 20
run: |
pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \
pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \
-r requirements/_integrations/accelerators.txt \
--extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}"
pip list
1 change: 1 addition & 0 deletions pyproject.toml
@@ -180,6 +180,7 @@ markers = [
]
filterwarnings = [
"error::FutureWarning",
"ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated
]
xfail_strict = true
junit_duration_report = "call"
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
@@ -4,5 +4,5 @@
torch >=2.1.0, <2.8.0
fsspec[http] >=2022.5.0, <2025.6.0
packaging >=20.0, <=25.0
typing-extensions >=4.4.0, <4.15.0
typing-extensions >4.4.0, <4.15.0
lightning-utilities >=0.10.0, <0.15.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
@@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2025.6.0
torchmetrics >=0.7.0, <1.8.0
packaging >=20.0, <=25.0
typing-extensions >=4.4.0, <4.15.0
typing-extensions >4.4.0, <4.15.0
lightning-utilities >=0.10.0, <0.15.0
2 changes: 2 additions & 0 deletions requirements/pytorch/serve.txt
@@ -0,0 +1,2 @@
fastapi >= 0.98.0
pydantic >= 1.10.22
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
@@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
numpy >=1.17.2, <1.27.0
onnx >=1.12.0, <1.19.0
onnxruntime >=1.12.0, <1.21.0
onnxscript >= 0.2.2, <0.3.0
psutil <7.0.1 # for `DeviceStatsMonitor`
pandas >2.0, <2.4.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/imports.py
@@ -34,6 +34,7 @@
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")

_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
6 changes: 2 additions & 4 deletions src/lightning/fabric/utilities/testing/_runif.py
@@ -17,7 +17,7 @@
from typing import Optional

import torch
from lightning_utilities.core.imports import RequirementCache, compare_version
from lightning_utilities.core.imports import compare_version
from packaging.version import Version

from lightning.fabric.accelerators import XLAAccelerator
@@ -112,9 +112,7 @@ def _runif_reasons(
reasons.append("Standalone execution")
kwargs["standalone"] = True

if deepspeed and not (
_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
):
if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4):
reasons.append("Deepspeed")

if dynamo:
22 changes: 19 additions & 3 deletions src/lightning/pytorch/core/module.py
@@ -47,6 +47,7 @@
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
@@ -74,8 +75,10 @@

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.onnx import ONNXProgram

_ONNX_AVAILABLE = RequirementCache("onnx")
_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")

warning_cache = WarningCache()
log = logging.getLogger(__name__)
@@ -1386,12 +1389,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
)

@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
def to_onnx(
self,
file_path: Union[str, Path, BytesIO, None] = None,
input_sample: Optional[Any] = None,
**kwargs: Any,
) -> Optional["ONNXProgram"]:
"""Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to.
file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
input_sample: An input for tracing. Default: None (Use self.example_input_array)

**kwargs: Will be passed to torch.onnx.export function.

Example::
@@ -1412,6 +1421,12 @@ def forward(self, x):
if not _ONNX_AVAILABLE:
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")

if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5):
raise ModuleNotFoundError(
f"`{type(self).__name__}.to_onnx(dynamo=True)` "
"requires `onnxscript` and `torch>=2.5.0` to be installed."
)

mode = self.training

if input_sample is None:
@@ -1428,8 +1443,9 @@ def forward(self, x):
file_path = str(file_path) if isinstance(file_path, Path) else file_path
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
# BytesIO does work, too.
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)
return ret

@torch.no_grad()
def to_torchscript(
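Note: with this change, `to_onnx` forwards the return value of `torch.onnx.export`, so the dynamo-based exporter hands back an `ONNXProgram` that can be run or saved later. A minimal sketch of the new call pattern, based on the tests in this PR (file names are illustrative):

import torch
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(5, 32)

# Legacy TorchScript-based exporter: writes the file, returns None.
model.to_onnx("model.onnx", export_params=True)

# Dynamo exporter (requires torch>=2.5 and onnxscript); file_path may now be omitted.
program = model.to_onnx(dynamo=True)
outputs = program(model.example_input_array)  # run the exported graph directly
program.save("model_dynamo.onnx")  # persist the program when a file is needed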
7 changes: 6 additions & 1 deletion src/lightning/pytorch/utilities/testing/_runif.py
@@ -18,7 +18,7 @@
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -42,6 +42,7 @@ def _runif_reasons(
psutil: bool = False,
sklearn: bool = False,
onnx: bool = False,
onnxscript: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

@@ -64,6 +65,7 @@
psutil: Require that psutil is installed.
sklearn: Require that scikit-learn is installed.
onnx: Require that onnx is installed.
onnxscript: Require that onnxscript is installed.

"""

@@ -96,4 +98,7 @@
if onnx and not _ONNX_AVAILABLE:
reasons.append("onnx")

if onnxscript and not _ONNXSCRIPT_AVAILABLE:
reasons.append("onnxscript")

return reasons, kwargs
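The new `onnxscript` flag composes with the existing `RunIf` conditions; a short sketch of how a test gates on it (the test body is illustrative, the marker combination mirrors the tests below):

from tests_pytorch.helpers.runif import RunIf

@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True)
def test_dynamo_export():
    ...  # runs only when onnx, onnxscript, torch>=2.5, and dynamo support are all present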
68 changes: 66 additions & 2 deletions tests/tests_pytorch/models/test_onnx.py
@@ -13,6 +13,7 @@
# limitations under the License.
import operator
import os
import re
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
@@ -25,6 +26,7 @@

import tests_pytorch.helpers.pipelines as tpipes
from lightning.pytorch import Trainer
from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.utilities.test_model_summary import UnorderedModel
@@ -139,8 +141,16 @@ def test_error_if_no_input(tmp_path):
model.to_onnx(file_path)


@pytest.mark.parametrize(
"dynamo",
[
None,
pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
],
)
@RunIf(onnx=True)
def test_if_inference_output_is_valid(tmp_path):
def test_if_inference_output_is_valid(tmp_path, dynamo):
"""Test that the output inferred from ONNX model is same as from PyTorch."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
@@ -153,7 +163,12 @@ def test_if_inference_output_is_valid(tmp_path):
torch_out = model(model.example_input_array)

file_path = os.path.join(tmp_path, "model.onnx")
model.to_onnx(file_path, model.example_input_array, export_params=True)
kwargs = {
"export_params": True,
}
if dynamo is not None:
kwargs["dynamo"] = dynamo
model.to_onnx(file_path, model.example_input_array, **kwargs)

ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {}
ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs)
@@ -167,3 +182,52 @@ def to_numpy(tensor):

# compare ONNX Runtime and PyTorch results
assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)


@RunIf(min_torch="2.5.0", dynamo=True)
@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.")
def test_model_onnx_export_missing_onnxscript():
"""Test that an error is raised if onnxscript is not available."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

with pytest.raises(
ModuleNotFoundError,
match=re.escape(
f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
),
):
model.to_onnx(dynamo=True)


@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True)
def test_model_return_type():
model = BoringModel()
model.example_input_array = torch.randn((1, 32))
model.eval()

onnx_pg = model.to_onnx(dynamo=True)

onnx_cls = torch.onnx.ONNXProgram if torch.__version__ >= "2.6.0" else torch.onnx._internal.exporter.ONNXProgram

assert isinstance(onnx_pg, onnx_cls)

model_ret = model(model.example_input_array)
inf_ret = onnx_pg(model.example_input_array)

assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05)


@RunIf(max_torch="2.5.0")
def test_model_onnx_export_wrong_torch_version():
"""Test that an error is raised if onnxscript is not available."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

with pytest.raises(
ModuleNotFoundError,
match=re.escape(
f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
),
):
model.to_onnx(dynamo=True)