
Torch-TensorRT Integration with LightningModule #20808


Open

GdoongMathew wants to merge 23 commits into master from feat/to_tensorrt

Changes from all commits (23 commits)
f589e31
feat: add `to_tensorrt` in the `LightningModule`.
GdoongMathew May 4, 2025
14b9f29
feat: add `to_tensorrt` in the `LightningModule`.
GdoongMathew May 4, 2025
314c463
refactor: fix `to_tensorrt` impl
GdoongMathew May 10, 2025
534c6c4
test: add test_torch_tensorrt.py
GdoongMathew May 10, 2025
e8a7fd2
Merge remote-tracking branch 'origin/feat/to_tensorrt' into feat/to_t…
GdoongMathew May 10, 2025
5c01acb
add dependency in test requirement.
GdoongMathew May 10, 2025
26d2788
update dependency in test requirement.
GdoongMathew May 10, 2025
d42ebe4
fix mypy error.
GdoongMathew May 21, 2025
958968d
fix: fix unittest.
GdoongMathew May 21, 2025
87a048b
Merge branch 'master' into feat/to_tensorrt
GdoongMathew May 21, 2025
53257f9
fix: fix unittest.
GdoongMathew May 23, 2025
f076233
Merge remote-tracking branch 'origin/feat/to_tensorrt' into feat/to_t…
GdoongMathew May 23, 2025
0723071
fix: fix type annotation.
GdoongMathew May 23, 2025
c97ae0d
fix: fix runif tensorrt logic.
GdoongMathew Jun 7, 2025
484f9ce
test: add test `test_missing_tensorrt_package`.
GdoongMathew Jun 7, 2025
b35c60c
req: remove mac from the tensorrt dependency.
GdoongMathew Jun 7, 2025
e6f8097
Merge branch 'master' into feat/to_tensorrt
GdoongMathew Jun 7, 2025
a9047fe
fix: fix default device logics.
GdoongMathew Jun 7, 2025
a36907f
test: add test `test_tensorrt_with_wrong_default_device`.
GdoongMathew Jun 7, 2025
937dedf
fix: reorder the import sequence.
GdoongMathew Jun 7, 2025
15f1961
feat: add exception when torch is below 2.2.0.
GdoongMathew Jun 13, 2025
67546dc
add unittests.
GdoongMathew Jun 13, 2025
dcac65d
Merge branch 'master' into feat/to_tensorrt
GdoongMathew Jun 21, 2025
2 changes: 2 additions & 0 deletions requirements/pytorch/test.txt
@@ -17,3 +17,5 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App

tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger`

torch-tensorrt >=1.4.0, <2.8.0; platform_system != "Darwin"
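
For context (not part of the diff), a minimal sketch of how an optional dependency declared this way is typically gated at runtime — it mirrors the `RequirementCache` pattern added to `module.py` below:

from lightning_utilities.core.imports import RequirementCache

# True only when the `torch_tensorrt` distribution is importable;
# the availability check runs once and is cached on the instance.
_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

if not _TORCH_TRT_AVAILABLE:
    # str() yields a human-readable reason for the failed requirement.
    print(str(_TORCH_TRT_AVAILABLE))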
108 changes: 107 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -13,11 +13,12 @@
# limitations under the License.
"""The LightningModule - an nn.Module with many additional features."""

import copy
import logging
import numbers
import weakref
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from io import BytesIO
from pathlib import Path
from typing import (
@@ -47,6 +48,7 @@
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
@@ -76,6 +78,7 @@
from torch.distributed.device_mesh import DeviceMesh

_ONNX_AVAILABLE = RequirementCache("onnx")
_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

warning_cache = WarningCache()
log = logging.getLogger(__name__)
@@ -1519,6 +1522,109 @@ def forward(self, x):

return torchscript_module

@torch.no_grad()
def to_tensorrt(
self,
file_path: Optional[Union[str, Path, BytesIO]] = None,
input_sample: Optional[Any] = None,
ir: Literal["default", "dynamo", "ts"] = "default",
output_format: Literal["exported_program", "torchscript"] = "exported_program",
retrace: bool = False,
default_device: Union[str, torch.device] = "cuda",
**compile_kwargs: Any,
) -> Union[ScriptModule, torch.fx.GraphModule]:
"""Export the model to ScriptModule or GraphModule using TensorRT compile backend.

Args:
file_path: Path where to save the tensorrt model. Default: None (no file saved).
input_sample: Inputs to be used during `torch_tensorrt.compile`.
Default: None (use :attr:`example_input_array`).
ir: The IR mode to use for TensorRT compilation. Default: "default".
output_format: The format of the output model. Default: "exported_program".
retrace: Whether to retrace the model. Default: False.
default_device: The device to move the model to when it is not already on a CUDA device. Default: "cuda".
**compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.

Example::

class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)

def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
input_sample = torch.randn(1, 64)
exported_program = model.to_tensorrt(
file_path="export.ep",
input_sample=input_sample,
)

"""
if not _TORCH_GREATER_EQUAL_2_2:
raise MisconfigurationException(
f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."
)

if not _TORCH_TRT_AVAILABLE:
[Review — Member]: Here, you can also raise an exception for some older PT versions if desired.

[Reply — Contributor Author]: Thanks for the suggestion!! I've added it in 15f1961

raise ModuleNotFoundError(
f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "
"Please install it with `pip install torch-tensorrt`."
)

mode = self.training
device = self.device
if self.device.type != "cuda":
default_device = torch.device(default_device) if isinstance(default_device, str) else default_device

if not torch.cuda.is_available() or default_device.type != "cuda":
raise MisconfigurationException(
f"TensorRT only supports CUDA devices. The current device is {self.device}."
f" Please set the `default_device` argument to a CUDA device."
)

self.to(default_device)

if input_sample is None:
if self.example_input_array is None:
raise ValueError(
"Could not export to TensorRT since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
input_sample = self.example_input_array

import torch_tensorrt

input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample)
input_sample = self._on_before_batch_transfer(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)

with _jit_is_scripting() if ir == "ts" else nullcontext():
trt_obj = torch_tensorrt.compile(
module=self.eval(),
ir=ir,
inputs=input_sample,
**compile_kwargs,
)
self.train(mode)
self.to(device)

if file_path is not None:
if ir == "ts" and output_format != "torchscript":
raise ValueError(
"TensorRT with IR mode 'ts' only supports output format 'torchscript'."
f" The current output format is {output_format}."
)
torch_tensorrt.save(
trt_obj,
file_path,
inputs=input_sample,
output_format=output_format,
retrace=retrace,
)
return trt_obj

@_restricted_classmethod
def load_from_checkpoint(
cls,
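
For readers skimming the diff, a minimal usage sketch of the new method, assuming a CUDA device and `torch_tensorrt` are available; `SimpleModel` is the hypothetical LightningModule from the docstring example above:

import torch

model = SimpleModel()  # hypothetical module from the docstring example
sample = torch.randn(1, 64)

# Compile with the dynamo frontend and save the artifact to disk.
trt_module = model.to_tensorrt(
    file_path="model.ep",
    input_sample=sample,
    ir="dynamo",
    output_format="exported_program",
)

# The returned module can be called directly on CUDA inputs.
out = trt_module(sample.cuda())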
6 changes: 5 additions & 1 deletion src/lightning/pytorch/utilities/testing/_runif.py
@@ -18,7 +18,7 @@
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE, _TORCH_TRT_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -42,6 +42,7 @@ def _runif_reasons(
psutil: bool = False,
sklearn: bool = False,
onnx: bool = False,
tensorrt: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

@@ -96,4 +97,7 @@
if onnx and not _ONNX_AVAILABLE:
reasons.append("onnx")

if tensorrt and not _TORCH_TRT_AVAILABLE:
reasons.append("torch-tensorrt")

return reasons, kwargs
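
To illustrate (a hypothetical test, not part of the diff), the new `tensorrt` flag composes with the existing `RunIf` markers exactly as the tests below use it:

import torch

from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf


# Skipped unless torch-tensorrt is installed, at least one CUDA GPU is
# present, and torch >= 2.2.0 — the same gating the new tests below use.
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_to_tensorrt_smoke():
    model = BoringModel()
    trt_module = model.to_tensorrt(input_sample=torch.randn(1, 32))
    assert trt_module is not None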
154 changes: 154 additions & 0 deletions tests/tests_pytorch/models/test_torch_tensorrt.py
@@ -0,0 +1,154 @@
import os
import re
from io import BytesIO
from pathlib import Path

import pytest
import torch

import tests_pytorch.helpers.pipelines as pipes
from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


@RunIf(max_torch="2.2.0")
def test_torch_minimum_version():
model = BoringModel()
with pytest.raises(
MisconfigurationException,
match=re.escape(f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."),
):
model.to_tensorrt("model.trt")


@pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="This test runs only when torch_tensorrt is not installed.")
def test_missing_tensorrt_package():
model = BoringModel()
with pytest.raises(
ModuleNotFoundError,
match=re.escape(f"`{type(model).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "),
):
model.to_tensorrt("model.trt")


@RunIf(tensorrt=True, min_torch="2.2.0")
def test_tensorrt_with_wrong_default_device(tmp_path):
model = BoringModel()
input_sample = torch.randn((1, 32))
file_path = os.path.join(tmp_path, "model.trt")
with pytest.raises(MisconfigurationException):
model.to_tensorrt(file_path, input_sample, default_device="cpu")


@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_saves_with_input_sample(tmp_path):
model = BoringModel()
ori_device = model.device
input_sample = torch.randn((1, 32))

file_path = os.path.join(tmp_path, "model.trt")
model.to_tensorrt(file_path, input_sample)

assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2
assert model.device == ori_device

file_path = Path(tmp_path) / "model.trt"
model.to_tensorrt(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2
assert model.device == ori_device

file_path = BytesIO()
model.to_tensorrt(file_path, input_sample)
assert len(file_path.getvalue()) > 4e2


@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_error_if_no_input(tmp_path):
model = BoringModel()
model.example_input_array = None
file_path = os.path.join(tmp_path, "model.trt")

with pytest.raises(
ValueError,
match=r"Could not export to TensorRT since neither `input_sample` nor "
r"`model.example_input_array` attribute is set.",
):
model.to_tensorrt(file_path)


@RunIf(tensorrt=True, min_cuda_gpus=2, min_torch="2.2.0")
def test_tensorrt_saves_on_multi_gpu(tmp_path):
trainer_options = {
"default_root_dir": tmp_path,
"max_epochs": 1,
"limit_train_batches": 10,
"limit_val_batches": 10,
"accelerator": "gpu",
"devices": [0, 1],
"strategy": "ddp_spawn",
"enable_progress_bar": False,
}

model = BoringModel()
model.example_input_array = torch.randn((4, 32))

pipes.run_model_test(trainer_options, model, min_acc=0.08)

file_path = os.path.join(tmp_path, "model.trt")
model.to_tensorrt(file_path)

assert os.path.exists(file_path)


@pytest.mark.parametrize(
("ir", "export_type"),
[
("default", torch.fx.GraphModule),
("dynamo", torch.fx.GraphModule),
("ts", torch.jit.ScriptModule),
],
)
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_save_ir_type(ir, export_type):
model = BoringModel()
model.example_input_array = torch.randn((4, 32))

ret = model.to_tensorrt(ir=ir)
assert isinstance(ret, export_type)


@pytest.mark.parametrize(
"output_format",
["exported_program", "torchscript"],
)
@pytest.mark.parametrize(
"ir",
["default", "dynamo", "ts"],
)
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_export_reload(output_format, ir, tmp_path):
if ir == "ts" and output_format == "exported_program":
pytest.skip("TorchScript cannot be exported as exported_program")

import torch_tensorrt

model = BoringModel()
model.cuda().eval()
model.example_input_array = torch.ones((4, 32))

file_path = os.path.join(tmp_path, "model.trt")
model.to_tensorrt(file_path, output_format=output_format, ir=ir)

loaded_model = torch_tensorrt.load(file_path)
if output_format == "exported_program":
loaded_model = loaded_model.module()

with torch.no_grad(), torch.inference_mode():
model_output = model(model.example_input_array.to("cuda"))
jit_output = loaded_model(model.example_input_array.to("cuda"))

assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06)