
feat: Implement support for exporting Torch-TensorRT compiled graphs using torch.export serde APIs #2249


Merged Oct 4, 2023 · 62 commits

Commits (62)
cc42ca3
feat: Express TRT engines as nodes instead of modules
peri044 Aug 2, 2023
afcd5ec
chore: Fix input nodes to TRT graph
peri044 Aug 4, 2023
58dcc4f
chore: prototype
peri044 Aug 4, 2023
a57f3c0
chore: minor change
peri044 Aug 7, 2023
f1f202e
feat: Move tracing to use aot export apis
peri044 Aug 8, 2023
abaf047
chore: minor changes
peri044 Aug 9, 2023
370099f
chore: minor change
peri044 Aug 10, 2023
bb1f3cf
chore: minor changes
peri044 Aug 11, 2023
3d05b4d
chore: Rebase with main
peri044 Aug 11, 2023
8d99be5
chore: rebase
peri044 Aug 16, 2023
0aad214
chore: minor logging updates
peri044 Aug 17, 2023
8899735
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
8af2627
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
f6969be
Key op fixes for failing tests
gs-olive Aug 5, 2023
bad1594
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
db56dd6
chore: Move to new export APIs
peri044 Aug 17, 2023
bf961f5
chore: rebase with dynamo_tensor_freeze branch
peri044 Aug 17, 2023
b13aa82
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
dd95620
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
6bd3c64
Key op fixes for failing tests
gs-olive Aug 5, 2023
248073f
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
3e5f434
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
6bf6945
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
3b6e1e7
Key op fixes for failing tests
gs-olive Aug 5, 2023
2107d8e
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
fd5a41e
chore: add BERT test case
peri044 Aug 18, 2023
f047651
chore: remove pdb
peri044 Aug 21, 2023
4862c68
chore: rebase with main
peri044 Aug 21, 2023
0ec68e6
chore: rebase with export_prototype branch
peri044 Aug 21, 2023
1a39cae
feat: Express TRTengines as nodes
peri044 Aug 22, 2023
ab76c0d
chore: rebase
peri044 Aug 23, 2023
0cac5ad
chore: refactor
peri044 Aug 24, 2023
e4df382
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
d022f4a
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
9610ba7
Key op fixes for failing tests
gs-olive Aug 5, 2023
e19aae7
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
2860be6
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Aug 25, 2023
ae98595
chore: refactor code and add test cases for serde
peri044 Aug 28, 2023
601ff44
chore: Add support for hybrid graph save/load by inlining pytorch sub…
peri044 Aug 28, 2023
1be093f
chore: rebase with export_prototype
peri044 Aug 28, 2023
d73ef1c
chore: minor updates
peri044 Sep 2, 2023
88328c6
chore: minor updates
peri044 Sep 5, 2023
a4251f1
chore: updates
peri044 Sep 8, 2023
3790362
chore: rebase with main
peri044 Sep 18, 2023
20e2a42
chore: updates
peri044 Sep 22, 2023
e588a17
chore: update docs
peri044 Sep 25, 2023
c457813
chore: rebase with main
peri044 Sep 25, 2023
5d33251
chore: uncomment a failing test
peri044 Sep 26, 2023
47822d6
chore: updates
peri044 Sep 26, 2023
16640d0
chore: rebase
peri044 Sep 30, 2023
7522a71
chore: rebase
peri044 Oct 1, 2023
3bcb02d
chore: address review comments
peri044 Oct 2, 2023
07f357c
chore: fix tests
peri044 Oct 2, 2023
b2b6373
chore: revert harness.py changes
peri044 Oct 2, 2023
4d82e17
chore: fix tests
peri044 Oct 2, 2023
200b03f
chore: address review comments
peri044 Oct 2, 2023
52017d2
chore: updates
peri044 Oct 2, 2023
29073c3
chore: updates
peri044 Oct 2, 2023
c4c6e5c
chore: rebase with main
peri044 Oct 3, 2023
fbe929f
chore: fix tests
peri044 Oct 3, 2023
9ea829d
chore: address review comments
peri044 Oct 3, 2023
f24a646
chore: revert fx changes
peri044 Oct 3, 2023
19 changes: 17 additions & 2 deletions .circleci/config.yml
@@ -802,7 +802,7 @@ commands:
- store_artifacts:
path: /tmp/testlogs

-test-dynamo-models_torch_export:
+test-dynamo-models_export:
description: "Test the Dynamo models via torch_export path"
steps:
- run:
@@ -818,6 +818,20 @@ commands:
- store_artifacts:
path: /tmp/testlogs

+test-dynamo-export_serde:
+description: "Test the export serialize/deserialize functionality for Dynamo models"
+steps:
+- run:
+name: Run Dynamo models and test export serde with TRT compiled modules
+command: |
+cd tests/py/dynamo/models
+pytest test_export_serde.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
+
+- store_test_results:
+path: /tmp/artifacts
+- store_artifacts:
+path: /tmp/testlogs
+
test-dynamo-converters:
description: "Test the Dynamo aten converters"
steps:
@@ -1122,7 +1136,8 @@ jobs:
- test-dynamo-backend
- test-dynamo-shared_utilities
- test-dynamo-models_torch_compile
-- test-dynamo-models_torch_export
+- test-dynamo-models_export
+- test-dynamo-export_serde

package-x86_64-linux:
parameters:
3 changes: 2 additions & 1 deletion .github/workflows/build-test.yml
@@ -141,7 +141,8 @@ jobs:
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
-${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
popd

tests-py-torch-compile-be:
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -42,6 +42,7 @@ User Guide
* :ref:`getting_started_with_fx`
* :ref:`ptq`
* :ref:`runtime`
+* :ref:`saving_models`
* :ref:`dynamic_shapes`
* :ref:`use_from_pytorch`
* :ref:`using_dla`
@@ -55,6 +56,7 @@ User Guide
user_guide/getting_started_with_fx_path
user_guide/ptq
user_guide/runtime
+user_guide/saving_models
user_guide/dynamic_shapes
user_guide/use_from_pytorch
user_guide/using_dla
77 changes: 77 additions & 0 deletions docsrc/user_guide/saving_models.rst
@@ -0,0 +1,77 @@
.. _saving_models:

Saving models compiled with Torch-TensorRT
==========================================

The procedure for saving models compiled with Torch-TensorRT varies slightly depending on the `ir` used for compilation.

1) Dynamo IR

Starting with the 2.1 release of Torch-TensorRT, the default compilation path is Dynamo-based.
Compiling with `ir=dynamo` produces a `torch.fx.GraphModule` object. There are two ways to save these objects:

a) Converting to Torchscript
`torch.fx.GraphModule` objects cannot be serialized directly. Hence, we use `torch.jit.trace` to convert them into a `ScriptModule` object, which can be saved to disk.
The following code illustrates this approach.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
trt_script_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_script_model, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(inputs)

b) ExportedProgram
`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resulting
`torch.fx.GraphModule`, together with additional metadata, can be used to create an `ExportedProgram`, which can be saved to and loaded from disk.

.. code-block:: python

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.export import transform, create_exported_program

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_gm = transform(trt_gm, inputs)
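# NOTE: call_spec (the call spec of the traced graph) is assumed to be defined; its construction is elided in this snippet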
trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
torch._export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
model = torch._export.load("trt_model.ep")
model(inputs)

`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
This is needed because `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
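
Conceptually, this inlining splices each submodule's graph into the parent graph at the corresponding `call_module` call site. The sketch below is a minimal, hypothetical illustration of the idea in plain `torch.fx`, not the actual Torch-TensorRT implementation; it assumes single-output submodules with no kwargs and no `get_attr` nodes, all of which a real pass must also handle.

.. code-block:: python

    import torch.fx as fx

    def inline_call_modules(gm: fx.GraphModule) -> fx.GraphModule:
        for node in list(gm.graph.nodes):
            if node.op != "call_module":
                continue
            submod = gm.get_submodule(node.target)
            if not isinstance(submod, fx.GraphModule):
                continue  # only FX-traced submodules can be inlined this way
            # Map the submodule's placeholders to the call-site arguments
            placeholders = [n for n in submod.graph.nodes if n.op == "placeholder"]
            env = dict(zip(placeholders, node.args))
            with gm.graph.inserting_before(node):
                for sub_node in submod.graph.nodes:
                    if sub_node.op == "placeholder":
                        continue
                    if sub_node.op == "output":
                        # Assumes a single-output submodule
                        node.replace_all_uses_with(env[sub_node.args[0]])
                        break
                    # Copy the node into the parent, remapping inputs through env
                    env[sub_node] = gm.graph.node_copy(sub_node, lambda n: env[n])
            gm.graph.erase_node(node)
        gm.graph.lint()
        gm.recompile()
        return gm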

NOTE: This way of saving the models using `ExportedProgram` is experimental. A known issue is tracked here: https://github.com/pytorch/TensorRT/issues/2341

2) Torchscript IR

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was via the TorchScript IR.
This behavior stays the same in the 2.X versions as well.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(inputs)

8 changes: 4 additions & 4 deletions py/torch_tensorrt/_compile.py
@@ -224,14 +224,14 @@ def compile(

# Export the module
torchtrt_inputs = prepare_inputs(input_list)
-module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
-compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
-module,
+exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+trt_graph_module = dynamo_compile(
+exp_program,
inputs=torchtrt_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
-return compiled_aten_module
+return trt_graph_module
elif target_ir == _IRType.torch_compile:
return torch_compile(
module, enabled_precisions=enabled_precisions_set, **kwargs
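
For reference, a minimal sketch of the two-step flow this change enables — `torch_tensorrt.dynamo.trace` now returns an `ExportedProgram`, which `torch_tensorrt.dynamo.compile` consumes. `MyModel` is a hypothetical user module; this snippet is illustrative and not part of the diff:

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical user-defined module
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Tracing returns a torch.export.ExportedProgram
exp_program = torch_tensorrt.dynamo.trace(model, inputs)
# Compilation consumes the exported program and returns a torch.fx.GraphModule
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)
```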
16 changes: 9 additions & 7 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -7,10 +7,12 @@
logger = logging.getLogger(__name__)

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-from ._settings import * # noqa: F403
-from ._SourceIR import SourceIR # noqa: F403
-from .aten_tracer import trace # noqa: F403
-from .compile import compile # noqa: F403
-from .conversion import * # noqa: F403
-from .conversion.converter_registry import DYNAMO_CONVERTERS # noqa: F403
-from .conversion.converter_registry import dynamo_tensorrt_converter # noqa: F403
+from ._settings import *
+from ._SourceIR import SourceIR
+from .aten_tracer import trace
+from .compile import compile
+from .conversion import *
+from .conversion.converter_registry import (
+DYNAMO_CONVERTERS,
+dynamo_tensorrt_converter,
+)
14 changes: 7 additions & 7 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -7,7 +7,10 @@
import torch
from torch._export import dynamic_dim, export
from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo._defaults import (
+ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+default_device,
+)
from torch_tensorrt.dynamo.lowering import get_decompositions
from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device

@@ -75,14 +78,11 @@ def trace(
trace_inputs.append(torch_inputs[idx])

experimental_decompositions = kwargs.get(
"enable_experimental_decompositions", False
"enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
)
with unittest.mock.patch(
"torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
-graph_module = export(
-model, tuple(trace_inputs), constraints=constraints
-).module()
+exp_program = export(model, tuple(trace_inputs), constraints=constraints)

-logger.debug("Post export graph: " + str(graph_module.graph))
-return graph_module
+return exp_program
41 changes: 21 additions & 20 deletions py/torch_tensorrt/dynamo/compile.py
@@ -1,10 +1,12 @@
from __future__ import annotations

import collections.abc
+import logging
from typing import Any, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch_tensorrt
+from torch.export import ExportedProgram
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
EngineCapability,
@@ -34,6 +36,7 @@
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.utils import (
get_torch_inputs,
+prepare_inputs,
set_log_level,
to_torch_device,
to_torch_tensorrt_device,
@@ -43,7 +46,7 @@


def compile(
-gm: Any,
+exported_program: ExportedProgram,
inputs: Any,
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
@@ -76,24 +79,23 @@ def compile(
if debug:
set_log_level(logger.parent, logging.DEBUG)

-if not isinstance(inputs, collections.abc.Sequence):
-inputs = [inputs]

+# Prepare torch_trt inputs
+inputs = prepare_inputs(inputs)
+device = to_torch_tensorrt_device(device)
+
+gm = exported_program.module()
+logger.debug("Input graph: " + str(gm.graph))
+
+# Apply lowering on the graph module
+torch_inputs = get_torch_inputs(inputs, device)
+gm = apply_lowering_passes(gm, torch_inputs)
+logger.debug("Lowered Input graph: " + str(gm.graph))

enabled_precisions = set(enabled_precisions)

-logger.warning(
-"The Dynamo backend is an experimental feature, for which only the "
-"following arguments are supported: "
-"{enabled_precisions, debug, workspace_size, min_block_size, "
-"max_aux_streams, version_compatible, optimization_level, "
-"torch_executed_ops, pass_through_build_failures, "
-"use_fast_partitioner, enable_experimental_decompositions, "
-"require_full_compilation}"
-)

-device = to_torch_tensorrt_device(device)

if (
torch.float16 in enabled_precisions
or torch_tensorrt.dtype.half in enabled_precisions
@@ -207,12 +209,11 @@ def compile_module(
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
+submodule = getattr(partitioned_module, name)
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
continue

-submodule = getattr(partitioned_module, name)
-
# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
@@ -239,19 +240,19 @@
name,
)

-# Create TRT Module from submodule
-trt_mod = convert_module(
+# Create TRT engines from submodule
+trt_module = convert_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

-trt_modules[name] = trt_mod
+trt_modules[name] = trt_module

# Replace all FX Modules with TRT Modules
-for name, trt_mod in trt_modules.items():
-setattr(partitioned_module, name, trt_mod)
+for name, trt_module in trt_modules.items():
+setattr(partitioned_module, name, trt_module)

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed: