
Commit 7b21322

peri044 and gs-olive authored
feat: Implement support for exporting Torch-TensorRT compiled graphs using torch.export serde APIs (#2249)
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com> Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
1 parent 7e5d05f commit 7b21322

File tree

13 files changed (+797 -43 lines)

.circleci/config.yml

+17 -2

@@ -802,7 +802,7 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
-  test-dynamo-models_torch_export:
+  test-dynamo-models_export:
     description: "Test the Dynamo models via torch_export path"
     steps:
       - run:
@@ -818,6 +818,20 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-dynamo-export_serde:
+    description: "Test the export serialize/deserialize functionality for Dynamo models"
+    steps:
+      - run:
+          name: Run Dynamo models and test export serde with TRT compiled modules
+          command: |
+            cd tests/py/dynamo/models
+            pytest test_export_serde.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
   test-dynamo-converters:
     description: "Test the Dynamo aten converters"
     steps:
@@ -1122,7 +1136,8 @@ jobs:
       - test-dynamo-backend
       - test-dynamo-shared_utilities
       - test-dynamo-models_torch_compile
-      - test-dynamo-models_torch_export
+      - test-dynamo-models_export
+      - test-dynamo-export_serde
 
   package-x86_64-linux:
     parameters:

.github/workflows/build-test.yml

+2 -1

@@ -141,7 +141,8 @@ jobs:
           cd tests/py/dynamo
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
-          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
           popd
 
   tests-py-torch-compile-be:

docsrc/index.rst

+2

@@ -42,6 +42,7 @@ User Guide
 * :ref:`getting_started_with_fx`
 * :ref:`ptq`
 * :ref:`runtime`
+* :ref:`saving_models`
 * :ref:`dynamic_shapes`
 * :ref:`use_from_pytorch`
 * :ref:`using_dla`
@@ -55,6 +56,7 @@ User Guide
    user_guide/getting_started_with_fx_path
    user_guide/ptq
    user_guide/runtime
+   user_guide/saving_models
    user_guide/dynamic_shapes
    user_guide/use_from_pytorch
    user_guide/using_dla

docsrc/user_guide/saving_models.rst

+77

@@ -0,0 +1,77 @@
+.. _saving_models:
+
+Saving models compiled with Torch-TensorRT
+============================================
+
+Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+
+1) Dynamo IR
+
+Starting with the 2.1 release of Torch-TensorRT, the default compilation path is dynamo based.
+The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects:
+
+a) Converting to TorchScript
+`torch.fx.GraphModule` objects cannot be serialized directly. Hence, we use `torch.jit.trace` to convert them into `ScriptModule` objects, which can be saved to disk.
+The following code illustrates this approach.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
+    trt_script_model = torch.jit.trace(trt_gm, inputs)
+    torch.jit.save(trt_script_model, "trt_model.ts")
+
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(inputs)
+
+b) ExportedProgram
+`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resultant
+`torch.fx.GraphModule`, along with additional metadata, can be used to create an `ExportedProgram`, which can be saved to and loaded from disk.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    from torch_tensorrt.dynamo.export import transform, create_exported_program
+
+    model = MyModel().eval().cuda()
+    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
+    # Transform and create an exported program
+    trt_gm = transform(trt_gm, inputs)
+    # call_spec is the CallSpec of the original exported program (not defined in this snippet)
+    trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
+    torch._export.save(trt_exp_program, "trt_model.ep")
+
+    # Later, you can load it and run inference
+    model = torch._export.load("trt_model.ep")
+    model(inputs)
+
+`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule into their corresponding nodes and stitches all the nodes together.
+This is needed because `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes); a minimal illustration of such a node follows below.
+
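A `call_module` node is how FX represents a call into a child module. The following minimal sketch (independent of the commit's code; it uses only plain `torch.fx` and a stock `torch.nn.Linear`) shows such a node appearing in a traced graph; nodes of this kind are what `transform` inlines before serialization.

.. code-block:: python

    import torch

    class Outer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # nn.Linear is a leaf module for FX tracing, so the call below is
            # recorded as a single call_module node instead of being traced through
            self.inner = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.inner(x))

    gm = torch.fx.symbolic_trace(Outer())
    # The printed graph contains a node with opcode "call_module" and target
    # "inner"; torch._export serde cannot round-trip such nodes.
    print(gm.graph)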
+NOTE: This way of saving models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
+
+2) TorchScript IR
+
+In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was using the TorchScript IR.
+This behavior stays the same in 2.X versions as well.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)  # Output is a ScriptModule object
+    torch.jit.save(trt_ts, "trt_model.ts")
+
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(inputs)

py/torch_tensorrt/_compile.py

+4 -4

@@ -224,14 +224,14 @@ def compile(
 
         # Export the module
         torchtrt_inputs = prepare_inputs(input_list)
-        module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
-        compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
-            module,
+        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+        trt_graph_module = dynamo_compile(
+            exp_program,
             inputs=torchtrt_inputs,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
-        return compiled_aten_module
+        return trt_graph_module
     elif target_ir == _IRType.torch_compile:
         return torch_compile(
             module, enabled_precisions=enabled_precisions_set, **kwargs

py/torch_tensorrt/dynamo/__init__.py

+9 -7

@@ -7,10 +7,12 @@
 logger = logging.getLogger(__name__)
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._settings import *  # noqa: F403
-    from ._SourceIR import SourceIR  # noqa: F403
-    from .aten_tracer import trace  # noqa: F403
-    from .compile import compile  # noqa: F403
-    from .conversion import *  # noqa: F403
-    from .conversion.converter_registry import DYNAMO_CONVERTERS  # noqa: F403
-    from .conversion.converter_registry import dynamo_tensorrt_converter  # noqa: F403
+    from ._settings import *
+    from ._SourceIR import SourceIR
+    from .aten_tracer import trace
+    from .compile import compile
+    from .conversion import *
+    from .conversion.converter_registry import (
+        DYNAMO_CONVERTERS,
+        dynamo_tensorrt_converter,
+    )

py/torch_tensorrt/dynamo/aten_tracer.py

+7 -7

@@ -7,7 +7,10 @@
 import torch
 from torch._export import dynamic_dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo._defaults import (
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    default_device,
+)
 from torch_tensorrt.dynamo.lowering import get_decompositions
 from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
 
@@ -75,14 +78,11 @@ def trace(
             trace_inputs.append(torch_inputs[idx])
 
     experimental_decompositions = kwargs.get(
-        "enable_experimental_decompositions", False
+        "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     )
     with unittest.mock.patch(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
-        graph_module = export(
-            model, tuple(trace_inputs), constraints=constraints
-        ).module()
+        exp_program = export(model, tuple(trace_inputs), constraints=constraints)
 
-    logger.debug("Post export graph: " + str(graph_module.graph))
-    return graph_module
+    return exp_program
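Together with the `_compile.py` change above, this means `trace` now returns the full `ExportedProgram` rather than calling `.module()` on it, and `dynamo.compile` extracts the GraphModule itself (see the next diff). A minimal sketch of the resulting two-step flow, assuming a user-defined `MyModel` as in the docs above:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder user module
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Step 1: tracing now yields a torch.export ExportedProgram
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)

    # Step 2: compile consumes the ExportedProgram and returns a
    # TRT-backed torch.fx.GraphModule
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)
    trt_gm(*inputs)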

py/torch_tensorrt/dynamo/compile.py

+21 -20

@@ -1,10 +1,12 @@
 from __future__ import annotations
 
+import collections.abc
 import logging
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 import torch_tensorrt
+from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ -34,6 +36,7 @@
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
+    prepare_inputs,
     set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
@@ -43,7 +46,7 @@
 
 
 def compile(
-    gm: Any,
+    exported_program: ExportedProgram,
     inputs: Any,
     *,
     device: Optional[Union[Device, torch.device, str]] = DEVICE,
@@ -76,24 +79,23 @@ def compile(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if not isinstance(inputs, collections.abc.Sequence):
+        inputs = [inputs]
+
+    # Prepare torch_trt inputs
+    inputs = prepare_inputs(inputs)
+    device = to_torch_tensorrt_device(device)
+
+    gm = exported_program.module()
+    logger.debug("Input graph: " + str(gm.graph))
+
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+    logger.debug("Lowered Input graph: " + str(gm.graph))
 
     enabled_precisions = set(enabled_precisions)
 
-    logger.warning(
-        "The Dynamo backend is an experimental feature, for which only the "
-        "following arguments are supported: "
-        "{enabled_precisions, debug, workspace_size, min_block_size, "
-        "max_aux_streams, version_compatible, optimization_level, "
-        "torch_executed_ops, pass_through_build_failures, "
-        "use_fast_partitioner, enable_experimental_decompositions, "
-        "require_full_compilation}"
-    )
-
-    device = to_torch_tensorrt_device(device)
-
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
@@ -207,12 +209,11 @@ def compile_module(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
+        submodule = getattr(partitioned_module, name)
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue
 
-        submodule = getattr(partitioned_module, name)
-
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module,
@@ -239,19 +240,19 @@ def compile_module(
             name,
         )
 
-        # Create TRT Module from submodule
-        trt_mod = convert_module(
+        # Create TRT engines from submodule
+        trt_module = convert_module(
             submodule,
             submodule_inputs,
             settings=settings,
             name=name,
         )
 
-        trt_modules[name] = trt_mod
+        trt_modules[name] = trt_module
 
     # Replace all FX Modules with TRT Modules
-    for name, trt_mod in trt_modules.items():
-        setattr(partitioned_module, name, trt_mod)
+    for name, trt_module in trt_modules.items():
+        setattr(partitioned_module, name, trt_module)
 
     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
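The last two hunks show the pattern `compile_module` uses throughout: iterate over the partitioned module's children, convert each accelerable submodule, then swap the converted module back in with `setattr`. A standalone sketch of that pattern, with a toy `Wrapped` class standing in for the TRT module that `convert_module` would produce:

.. code-block:: python

    import torch

    class Wrapped(torch.nn.Module):
        """Toy stand-in for a TRT-backed module produced by conversion."""

        def __init__(self, inner: torch.nn.Module):
            super().__init__()
            self.inner = inner

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.inner(x)

    def swap_accelerable_children(parent: torch.nn.Module) -> torch.nn.Module:
        replacements = {}
        for name, submodule in parent.named_children():
            # Toy criteria: convert only Linear layers
            if isinstance(submodule, torch.nn.Linear):
                replacements[name] = Wrapped(submodule)
        # Swap converted modules back into the parent, as compile_module does
        for name, new_module in replacements.items():
            setattr(parent, name, new_module)
        return parent

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    swap_accelerable_children(model)
    print(model)  # the Linear child is now Wrapped; ReLU is untouched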
