Torch-Tensorrt Integration with LightningModule #20808
Open · GdoongMathew wants to merge 23 commits into Lightning-AI:master from GdoongMathew:feat/to_tensorrt
+268 −2
Commits (23), all by GdoongMathew:

- `f589e31` feat: add `to_tensorrt` in the `LightningModule`.
- `14b9f29` feat: add `to_tensorrt` in the `LightningModule`.
- `314c463` refactor: fix `to_tensorrt` impl
- `534c6c4` test: add test_torch_tensorrt.py
- `e8a7fd2` Merge remote-tracking branch 'origin/feat/to_tensorrt' into feat/to_t…
- `5c01acb` add dependency in test requirement.
- `26d2788` update dependency in test requirement.
- `d42ebe4` fix mypy error.
- `958968d` fix: fix unittest.
- `87a048b` Merge branch 'master' into feat/to_tensorrt
- `53257f9` fix: fix unittest.
- `f076233` Merge remote-tracking branch 'origin/feat/to_tensorrt' into feat/to_t…
- `0723071` fix: fix type annotation.
- `c97ae0d` fix: fix runif tensorrt logic.
- `484f9ce` test: add test `test_missing_tensorrt_package`.
- `b35c60c` req: remove mac from the tensorrt dependency.
- `e6f8097` Merge branch 'master' into feat/to_tensorrt
- `a9047fe` fix: fix default device logics.
- `a36907f` test: add test `test_tensorrt_with_wrong_default_device`.
- `937dedf` fix: reorder the import sequence.
- `15f1961` feat: add exception when torch is below 2.2.0.
- `67546dc` add unittests.
- `dcac65d` Merge branch 'master' into feat/to_tensorrt
New test file, `test_torch_tensorrt.py` (@@ -0,0 +1,154 @@):

```python
import os
import re
from io import BytesIO
from pathlib import Path

import pytest
import torch

import tests_pytorch.helpers.pipelines as pipes
from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


@RunIf(max_torch="2.2.0")
def test_torch_minimum_version():
    model = BoringModel()
    with pytest.raises(
        MisconfigurationException,
        match=re.escape(f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."),
    ):
        model.to_tensorrt("model.trt")


@pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.")
def test_missing_tensorrt_package():
    model = BoringModel()
    with pytest.raises(
        ModuleNotFoundError,
        match=re.escape(f"`{type(model).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "),
    ):
        model.to_tensorrt("model.trt")


@RunIf(tensorrt=True, min_torch="2.2.0")
def test_tensorrt_with_wrong_default_device(tmp_path):
    model = BoringModel()
    input_sample = torch.randn((1, 32))
    file_path = os.path.join(tmp_path, "model.trt")
    with pytest.raises(MisconfigurationException):
        model.to_tensorrt(file_path, input_sample, default_device="cpu")


@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_saves_with_input_sample(tmp_path):
    model = BoringModel()
    ori_device = model.device
    input_sample = torch.randn((1, 32))

    file_path = os.path.join(tmp_path, "model.trt")
    model.to_tensorrt(file_path, input_sample)

    assert os.path.isfile(file_path)
    assert os.path.getsize(file_path) > 4e2
    assert model.device == ori_device

    file_path = Path(tmp_path) / "model.trt"
    model.to_tensorrt(file_path, input_sample)
    assert os.path.isfile(file_path)
    assert os.path.getsize(file_path) > 4e2
    assert model.device == ori_device

    file_path = BytesIO()
    model.to_tensorrt(file_path, input_sample)
    assert len(file_path.getvalue()) > 4e2


@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_error_if_no_input(tmp_path):
    model = BoringModel()
    model.example_input_array = None
    file_path = os.path.join(tmp_path, "model.trt")

    with pytest.raises(
        ValueError,
        match=r"Could not export to TensorRT since neither `input_sample` nor "
        r"`model.example_input_array` attribute is set.",
    ):
        model.to_tensorrt(file_path)


@RunIf(tensorrt=True, min_cuda_gpus=2, min_torch="2.2.0")
def test_tensorrt_saves_on_multi_gpu(tmp_path):
    trainer_options = {
        "default_root_dir": tmp_path,
        "max_epochs": 1,
        "limit_train_batches": 10,
        "limit_val_batches": 10,
        "accelerator": "gpu",
        "devices": [0, 1],
        "strategy": "ddp_spawn",
        "enable_progress_bar": False,
    }

    model = BoringModel()
    model.example_input_array = torch.randn((4, 32))

    pipes.run_model_test(trainer_options, model, min_acc=0.08)

    file_path = os.path.join(tmp_path, "model.trt")
    model.to_tensorrt(file_path)

    assert os.path.exists(file_path)


@pytest.mark.parametrize(
    ("ir", "export_type"),
    [
        ("default", torch.fx.GraphModule),
        ("dynamo", torch.fx.GraphModule),
        ("ts", torch.jit.ScriptModule),
    ],
)
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_save_ir_type(ir, export_type):
    model = BoringModel()
    model.example_input_array = torch.randn((4, 32))

    ret = model.to_tensorrt(ir=ir)
    assert isinstance(ret, export_type)


@pytest.mark.parametrize(
    "output_format",
    ["exported_program", "torchscript"],
)
@pytest.mark.parametrize(
    "ir",
    ["default", "dynamo", "ts"],
)
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
def test_tensorrt_export_reload(output_format, ir, tmp_path):
    if ir == "ts" and output_format == "exported_program":
        pytest.skip("TorchScript cannot be exported as exported_program")

    import torch_tensorrt

    model = BoringModel()
    model.cuda().eval()
    model.example_input_array = torch.ones((4, 32))

    file_path = os.path.join(tmp_path, "model.trt")
    model.to_tensorrt(file_path, output_format=output_format, ir=ir)

    loaded_model = torch_tensorrt.load(file_path)
    if output_format == "exported_program":
        loaded_model = loaded_model.module()

    with torch.no_grad(), torch.inference_mode():
        model_output = model(model.example_input_array.to("cuda"))
        jit_output = loaded_model(model.example_input_array.to("cuda"))

    assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06)
```
Review comment: Here, you can also raise an exception for some older PT versions if desired.
Reply (GdoongMathew): Thanks for the suggestion! I've added it in 15f1961.
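Commit 15f1961 adds that guard, and `test_torch_minimum_version` pins down its error message. A minimal sketch of such a version check, with a simplified parser and a generic `RuntimeError` standing in for Lightning's `MisconfigurationException` (the function name and parsing details are illustrative, not the PR's code):

```python
def check_min_torch(current_version: str, minimum: str = "2.2") -> None:
    """Hypothetical guard: raise if `current_version` is below `minimum`."""

    def parse(version: str) -> tuple:
        # Keep only leading numeric components: "2.1.0+cu121" -> (2, 1, 0).
        parts = []
        for piece in version.split("+")[0].split("."):
            if not piece.isdigit():
                break
            parts.append(int(piece))
        return tuple(parts)

    if parse(current_version) < parse(minimum):
        raise RuntimeError(
            f"TensorRT export requires PyTorch {minimum} or higher. "
            f"Current version is {current_version}."
        )
```

Comparing integer tuples rather than raw strings matters here: a string comparison would wrongly rank "2.10" below "2.2".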