
Commit bc6e1fa

chore: Switch to new export apis (#2376)

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

1 parent 460a6e2 commit bc6e1fa

File tree: 7 files changed, +120 -113 lines

.github/workflows/build-test.yml (+27 -1)

@@ -141,10 +141,36 @@ jobs:
         cd tests/py/dynamo
         ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
         ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
-        ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
         ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
         popd
 
+  tests-py-dynamo-serde:
+    name: Test dynamo export serde [Python]
+    needs: [generate-matrix, build]
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - repository: pytorch/tensorrt
+            package-name: torch_tensorrt
+            pre-script: packaging/pre_build_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    with:
+      job-name: tests-py-dynamo-serde
+      repository: "pytorch/tensorrt"
+      ref: ""
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+      pre-script: ${{ matrix.pre-script }}
+      script: |
+        export USE_HOST_DEPS=1
+        pushd .
+        cd tests/py/dynamo
+        ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
+        ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+        popd
+
   tests-py-torch-compile-be:
     name: Test torch compile backend [Python]
     needs: [generate-matrix, build]

py/torch_tensorrt/_Input.py (+6 -1)

@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_dtype: torch.dtype = torch.float32
     torch_tensor: torch.Tensor = None
+    name: str = ""
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
             tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
                 Note: Entering "None" (or not specifying) will set the bound to [0, 2)
-
+            torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input.
+            name (str, optional): Name of this input in the input nn.Module's forward function. Used to specify dynamic shapes for the corresponding input in dynamo tracer.
         Examples:
             - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
             - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
@@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             else:
                 self.torch_tensor = self.example_tensor()
 
+        if "name" in kwargs:
+            self.name = kwargs["name"]
+
     def __str__(self) -> str:
         if self.shape_mode == Input._ShapeMode.STATIC:
             return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(

py/torch_tensorrt/dynamo/_compiler.py (+8 -1)

@@ -34,7 +34,7 @@
     convert_module,
     repair_long_or_double_inputs,
 )
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
     prepare_inputs,
@@ -146,6 +146,13 @@ def compile(
     inputs = prepare_inputs(inputs)
     device = to_torch_tensorrt_device(device)
 
+    if not isinstance(exported_program, ExportedProgram):
+        raise AssertionError(
+            f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
+        )
+    exported_program = exported_program.run_decompositions(
+        get_decompositions(enable_experimental_decompositions)
+    )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
 
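
With this change, `compile` rejects anything that is not an `ExportedProgram` and applies the decomposition table itself via `run_decompositions`. A sketch of the resulting call pattern (the toy model and shapes are assumptions, not part of this commit):

import torch
import torch_tensorrt

class Toy(torch.nn.Module):  # hypothetical stand-in model
    def forward(self, x):
        return torch.relu(x)

model = Toy().eval().cuda()
inputs = (torch.randn(1, 3, 224, 224).cuda(),)

exp_program = torch.export.export(model, inputs)  # trace first
trt_module = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)  # then compile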

py/torch_tensorrt/dynamo/_exporter.py (+7)

@@ -229,6 +229,8 @@ def create_trt_exp_program(
     """
     input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
     output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
+    assert output_nodes
+    output_nodes = output_nodes[0].args[0]
 
     input_specs = [
         InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target)
@@ -276,6 +278,7 @@ def inline_trt_modules(
         (trt_module_node.args, trt_module.engine),
     )
     trt_node.meta["val"] = []
+    assert num_outputs > 0
     # Generate meta data for TRT node (a FakeTensor with corresponding output shape)
     for idx in range(num_outputs):
         trt_node.meta["val"].append(
@@ -292,12 +295,16 @@ def inline_trt_modules(
         # Insert getitem nodes as outputs (for export serialization to work)
         with gm.graph.inserting_after(trt_node):
             getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
+            getitem_output.meta["val"] = trt_node.meta["val"]
         trt_module_node.replace_all_uses_with(getitem_output)
     else:
         # Multiple outputs case:
         # Replace uses of submodule with the trt_node.
         # getitem nodes are already added inherently by the partitioner
         trt_module_node.replace_all_uses_with(trt_node)
+        getitem_nodes = trt_node.users
+        for idx, getitem_node in enumerate(getitem_nodes):
+            getitem_node.meta["val"] = trt_node.meta["val"][idx]
 
     # Erase the TRT submodule (call_module) node.
     gm.graph.erase_node(trt_module_node)
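
The multi-output fix above amounts to copying the producer node's per-output `meta["val"]` entries onto its `getitem` users. A self-contained FX sketch of that pattern (the toy module and the placeholder tensors standing in for FakeTensors are assumptions):

import operator

import torch
import torch.fx as fx

class Split(torch.nn.Module):  # hypothetical multi-output graph
    def forward(self, x):
        parts = torch.split(x, 2)
        return parts[0] + parts[1]

gm = fx.symbolic_trace(Split())

for node in gm.graph.nodes:
    if node.target == torch.split:
        # Stand-ins for the FakeTensors recorded per engine output.
        node.meta["val"] = [torch.empty(2, 4), torch.empty(2, 4)]
        # Propagate the matching entry to each getitem user so export
        # serialization can recover per-output shapes.
        for getitem_node in node.users:
            if getitem_node.target == operator.getitem:
                getitem_node.meta["val"] = node.meta["val"][getitem_node.args[1]]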

py/torch_tensorrt/dynamo/_tracer.py (+23 -62)

@@ -1,45 +1,20 @@
 from __future__ import annotations
 
 import logging
-import unittest.mock
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Tuple
 
 import torch
-from torch._export import dynamic_dim, export
-from torch_tensorrt._Device import Device
+from torch.export import Dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import (
-    DEBUG,
-    DEVICE,
-    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
-    default_device,
-)
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo._defaults import DEBUG, default_device
 from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
 
 logger = logging.getLogger(__name__)
 
 
-def get_random_tensor(
-    shape: List[Any], dtype: torch.dtype, device: torch.device
-) -> torch.Tensor:
-    if dtype == torch.int32 or dtype == torch.int64:
-        return torch.randint(2, 10, shape, dtype=dtype, device=device)
-    elif dtype in (torch.float64, torch.float32, torch.float16):
-        return torch.randn(shape, dtype=dtype, device=device)
-    else:
-        logger.critical(
-            "Invalid dtype detected in creating input tensors for tracing the graph."
-        )
-        raise
-
-
 def trace(
     mod: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
-    device: Optional[Union[Device, torch.device, str]] = DEVICE,
-    debug: bool = DEBUG,
-    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
 ) -> torch.export.ExportedProgram:
     """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT
@@ -65,9 +40,9 @@ def trace(
             torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
         ]
     Keyword Arguments:
-        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+        device (Union(torch.device, dict)): Target device for TensorRT engines to run on ::
 
-            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+            device=torch.device("cuda:0")
 
         debug (bool): Enable debuggable engine
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
@@ -77,50 +52,36 @@ def trace(
     """
 
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
+    debug = kwargs.get("debug", DEBUG)
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
-    device = to_torch_device(device if device else default_device())
 
-    # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT
-    # Torch dynamo does not allow 0/1 value for dynamic dimensions
-    # for inputs during tracing. Hence we create new inputs for export
+    device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
-    trace_inputs = []
-    constraints = []
-    for idx, input in enumerate(inputs):
-        if input.shape_mode == Input._ShapeMode.DYNAMIC:
+    dynamic_shapes = {}
+    for input in inputs:
+        if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
+            if not input.name:
+                raise AssertionError(
+                    f"Expected a name for a dynamic input with shape {input.shape} but found none"
+                )
             min_shape = input.shape["min_shape"]
            opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
             assert len(min_shape) == len(opt_shape) == len(max_shape)
-
-            constraint_dims = []
-            new_shape = []
+            dynamic_dims = {}
             for dim in range(len(min_shape)):
                 if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
-                    new_shape.append(torch_inputs[idx].shape[dim])
+                    continue
                 else:
-                    constraint_dims.append(dim)
-                    if torch_inputs[idx].shape[dim] == 1:
-                        new_shape.append(torch_inputs[idx].shape[dim] + 1)
-                    else:
-                        new_shape.append(torch_inputs[idx].shape[dim])
-
-            trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device)
+                    dynamic_dims[dim] = Dim(
+                        input.name + "_" + str(dim),
+                        min=min_shape[dim],
+                        max=max_shape[dim],
+                    )
 
-            for dim in constraint_dims:
-                if min_shape[dim] > 1:
-                    constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
-                if max_shape[dim] > 1:
-                    constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
-            trace_inputs.append(trace_input)
-        else:
-            trace_inputs.append(torch_inputs[idx])
+            dynamic_shapes[input.name] = dynamic_dims
 
-    with unittest.mock.patch(
-        "torch._export.DECOMP_TABLE",
-        get_decompositions(enable_experimental_decompositions),
-    ):
-        exp_program = export(mod, tuple(trace_inputs), constraints=constraints)
+    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
 
     return exp_program
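
The core of the switch: `torch._export.dynamic_dim` constraints (which disallowed 0/1-sized dynamic dimensions and forced synthetic trace inputs) give way to `torch.export.Dim` ranges passed through `dynamic_shapes`. A standalone sketch of what `trace` now emits for a dynamic batch dimension (the module is hypothetical; the bounds mirror the tests below):

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):  # hypothetical model with a single input "x"
    def forward(self, x):
        return x * 2

# One Dim per dynamic axis, named "<input name>_<dim index>" as in trace().
batch = Dim("x_0", min=1, max=8)

ep = export(
    Toy(),
    (torch.randn(4, 3, 224, 224),),    # example input at the opt shape
    dynamic_shapes={"x": {0: batch}},  # keyed by the forward() argument name
)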

tests/py/dynamo/models/test_dyn_models.py (+2)

@@ -36,6 +36,7 @@ def forward(self, x):
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
                 dtype=torch.float32,
+                name="x",
             )
         ],
         "device": torchtrt.Device("cuda:0"),
@@ -88,6 +89,7 @@ def forward(self, x):
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
                 dtype=torch.float32,
+                name="x",
             )
         ],
         "device": torchtrt.Device("cuda:0"),
