Commit ce063a2

chore: update test case
1 parent fdbd3d8 commit ce063a2

File tree: 6 files changed (+18 −75 lines)

py/torch_tensorrt/_compile.py (+6 −1)

@@ -12,6 +12,9 @@
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
     Save the model to disk in the specified output format.

     Arguments:
-        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
+        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
     """
+    if isinstance(module, WrapperTorchTensorRTModule):
+        module = module.original_module
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
     if arg_inputs is not None and not all(

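For context, a minimal usage sketch of the updated save() path. MyModel, the shapes, and the file name are hypothetical; only the unwrapping behavior comes from the diff above:

import torch
import torch_tensorrt

class MyModel(torch.nn.Module):  # hypothetical stand-in model
    def forward(self, x):
        return torch.relu(x + 2)

model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# compile() may hand back a WrapperTorchTensorRTModule when some
# subgraphs fall back to Torch (see _compiler.py below).
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# save() now accepts the wrapper directly and unwraps it to the
# original module before parsing the module type.
torch_tensorrt.save(trt_mod, "trt_model.ep", inputs=inputs)
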
py/torch_tensorrt/dynamo/_compiler.py (+1 −1)

@@ -835,7 +835,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

     dryrun_stats_display(dryrun_tracker, settings.dryrun)

-    if len(trt_modules) > 1:
+    if len(dryrun_tracker.to_run_in_torch) > 0:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
         partitioned_module = WrapperTorchTensorRTModule(
             partitioned_module,

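In effect (a hedged reading of the condition names): the CUDA-graph wrapper is now applied whenever the partitioner left at least one subgraph running in eager Torch, rather than whenever compilation produced more than one TRT module. A fully converted model that happens to contain several TRT engines no longer gets wrapped.
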
py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (+1)

@@ -78,6 +78,7 @@ def __init__(
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+
         # TODO: Make the below a Dictionary {shape: cudagraph}
         self.shape_key: Optional[str] = None

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py (+1 −2)

@@ -46,7 +46,7 @@ def __init__(
             if "_run_on_acc" in name:
                 rt_mod.set_cudagraphs_enabled_parent_module(True)

-        # TODO: check if only torch needs warm up.
+        # Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
         with unset_fake_temporarily():
             inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
             s = torch.cuda.Stream()
@@ -256,7 +256,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self._caller_stream.wait_stream(self._engine_stream)

         if cudagraphs_enabled:
-            # TODO: submodule to return list only
             if isinstance(self._output_buffers, (list, tuple)):
                 output_buffers = self._output_buffers
             else:

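The replaced TODO states the standard CUDA Graphs caveat: the first pass through a model performs one-time allocations and lazy initializations, and anything recorded during capture is replayed verbatim, so that work must happen before recording. A minimal sketch of the usual warm-up-then-capture recipe with stock PyTorch APIs (the helper and iteration count are illustrative, not this module's actual code):

import torch

def capture_with_warmup(fn, static_inputs, warmup_iters=3):
    # Warm up on a side stream so one-time allocations and lazy CUDA
    # initializations happen before anything is recorded.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):
            fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this context are recorded,
    # not executed eagerly.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = fn(*static_inputs)

    # Replay later with g.replay() after copying fresh data into
    # static_inputs; static_output is overwritten in place.
    return g, static_output
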
py/torch_tensorrt/runtime/_weight_streaming.py (+9 −1)

@@ -3,6 +3,9 @@

 import torch
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)

 logger = logging.getLogger(__name__)

@@ -12,9 +15,14 @@ class _WeightStreamingContextManager(object):
     Helper class used to setup weight streaming budget
     """

-    def __init__(self, module: torch.fx.GraphModule) -> None:
+    def __init__(
+        self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule
+    ) -> None:
         rt_mods = []
         self.current_device_budget = 0
+
+        if isinstance(module, WrapperTorchTensorRTModule):
+            module = module.original_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
                 rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)

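The public entry point that constructs this context manager is torch_tensorrt.runtime.weight_streaming; with this change it can be handed the wrapped runtime module directly. A sketch, assuming trt_model was compiled with weight streaming enabled and inputs is defined (attribute names follow the documented weight-streaming API):

import torch_tensorrt

# trt_model may be a WrapperTorchTensorRTModule; the context manager
# now unwraps it to the original module itself.
with torch_tensorrt.runtime.weight_streaming(trt_model) as ws_ctx:
    # Keep half of the weights on device; stream the rest at runtime.
    ws_ctx.device_budget = ws_ctx.total_device_budget // 2
    out = trt_model(*inputs)
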
tests/py/dynamo/runtime/test_002_cudagraphs_py.py (−70)

@@ -158,76 +158,6 @@ def forward(self, x):
                 msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
             )

-    def test_cudagraphs_dynamic_py(self):
-        class SampleModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.relu((x + 2) * 0.5)
-
-        # TODO: more dynamic dim
-        # TODO: multiple output
-        # TODO: module that graph cannot be used
-        inputs = torch_tensorrt.Input(
-            min_shape=(1, 3, 224, 224),
-            opt_shape=(8, 3, 224, 224),
-            max_shape=(16, 3, 224, 224),
-            dtype=torch.float,
-            name="x",
-        )
-        fx_graph = torch.fx.symbolic_trace(SampleModel())
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "dynamo",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=True,
-        )
-
-        result_samples = []
-        torch_results_samples = []
-
-        inputs = []
-        for i in [1, 3, 8, 11, 16]:
-            inputs.append(torch.randn((i, 3, 224, 224)).cuda())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # disable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # enable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for i, (optimized_model_results, torch_model_results) in enumerate(
-            zip(result_samples, torch_results_samples)
-        ):
-            max_diff = float(
-                torch.max(torch.abs(optimized_model_results - torch_model_results))
-            )
-            self.assertAlmostEqual(
-                max_diff,
-                0,
-                DECIMALS_OF_AGREEMENT,
-                msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
-            )
-

 if __name__ == "__main__":
     run_tests()
