skip dummy inference and run_shape_analysis (#3212)
lanluo-nvidia authored Oct 29, 2024
1 parent e2a27a0 commit bfa4c9a
Showing 15 changed files with 227 additions and 167 deletions.
23 changes: 18 additions & 5 deletions py/torch_tensorrt/_compile.py
@@ -502,19 +502,24 @@ def save(
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
torch.jit.save(module, file_path)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
"Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
torch.export.save(module, file_path)
elif module_type == _ModuleType.fx:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
)

# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(
@@ -525,11 +530,19 @@ def save(
if not retrace:
from torch_tensorrt.dynamo._exporter import export

exp_program = export(module, arg_inputs, kwarg_inputs)
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
)
exp_program = export(module)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
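Taken together, the _compile.py changes mean torch_tensorrt.save no longer requires example inputs unless retrace=True. A minimal usage sketch, assuming a GPU build of Torch-TensorRT; MyModel, the tensor shapes, and the file names are illustrative, not part of this PR:

    import torch
    import torch_tensorrt

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    model = MyModel().eval().cuda()
    arg_inputs = (torch.randn(2, 4).cuda(),)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=arg_inputs)

    # retrace=False (the default): export() now reuses the metadata recorded at
    # compile time, so no inputs are needed; passing them only logs the new warning.
    torch_tensorrt.save(trt_gm, "trt_model.ep")

    # retrace=True: torch.export.export re-traces the module, so inputs are still required.
    torch_tensorrt.save(trt_gm, "trt_model_retraced.ep", arg_inputs=arg_inputs, retrace=True)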
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,7 @@
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
prepare_inputs,
set_log_level,
@@ -302,7 +303,6 @@ def compile(

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(enable_experimental_decompositions)
@@ -433,6 +433,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
if not settings.use_fast_partitioner:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

submodule_node_dict = {}
for node in partitioned_module.graph.nodes:
if "_run_on_acc" not in node.name:
continue
submodule_node_dict[node.name] = node

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
@@ -452,6 +458,26 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
)
continue

if name not in submodule_node_dict:
raise ValueError(
f"node_name: {name} does not exist in the submodule node dictionary"
)

# set the submodule metadata back to the parent trt_module_node
metadata_list = get_output_metadata(submodule)
assert len(metadata_list) > 0
metadata_keys = ["val", "tensor_meta"]
for key in metadata_keys:
if key not in submodule_node_dict[name].meta:
meta_val_list = [
metadata[key] for metadata in metadata_list if key in metadata
]
submodule_node_dict[name].meta[key] = meta_val_list
logger.debug(
f"Updated metadata for node: {name} with its corresponding submodule outputs"
)
break

subgraph_data = PerSubgraphData()
subgraph_data.subgraph_name = name
subgraph_data.subgraph_op_count = len(
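The loop above records each accelerated submodule's output metadata ("val"/"tensor_meta") on its parent _run_on_acc node before the submodule is swapped for a TRT engine, so later stages can recover output shapes and dtypes without dummy inference. get_output_metadata is added to torch_tensorrt.dynamo.utils in one of the files not shown here; a hedged sketch of what it plausibly does (the body is an assumption, only the name comes from this diff):

    from typing import Any, Dict, List

    import torch

    def get_output_metadata_sketch(gm: torch.fx.GraphModule) -> List[Dict[str, Any]]:
        # An FX graph has exactly one output node; its args hold the returned values.
        output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
        assert len(output_nodes) == 1
        outputs = output_nodes[0].args[0]
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs,)
        # Each producer node carries shape/dtype info in .meta ("val"/"tensor_meta").
        return [node.meta for node in outputs]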
41 changes: 11 additions & 30 deletions py/torch_tensorrt/dynamo/_exporter.py
@@ -1,6 +1,6 @@
import copy
import operator
from typing import Any, Dict, Optional, Sequence, Tuple, cast
from typing import Any, Dict, Sequence, Tuple, cast

import torch
from torch._guards import detect_fake_mode
@@ -16,31 +16,24 @@
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo import partitioning


def export(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
kwarg_inputs: Optional[dict[str, Any]] = None,
) -> ExportedProgram:
"""Export the result of TensorRT compilation into the desired output format.
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
"""
if kwarg_inputs is None:
kwarg_inputs = {}
patched_module = transform(gm, inputs, kwarg_inputs)
patched_module = transform(gm)
exp_program = create_trt_exp_program(patched_module)
return exp_program


def transform(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
kwarg_inputs: Optional[dict[str, Any]] = None,
) -> torch.fx.GraphModule:
"""
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -55,14 +48,10 @@ def transform(
"""
# Make a copy of the graph since this function transforms the input graph and changes its attributes.
# This transformed graph is meant to be consumed by `create_trt_exp_program`
if kwarg_inputs is None:
kwarg_inputs = {}
gm = copy.deepcopy(gm)
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwarg_inputs)

# Inline TensorRT submodules
inline_trt_modules(gm, outputs_map)
inline_trt_modules(gm)

# Inline pytorch submodules
inline_torch_modules(gm)
@@ -361,9 +350,7 @@ def create_trt_exp_program(
return trt_exp_program


def inline_trt_modules(
gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]]
) -> torch.fx.GraphModule:
def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Replace TRT submodules with trt engine nodes.
"""
@@ -379,7 +366,11 @@ def inline_trt_modules(
trt_module_node = trt_module_node[0]
assert trt_module_node.args

num_outputs = len(outputs_map[trt_module_node.name])
if "val" not in trt_module_node.meta:
raise ValueError(
f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step."
)
num_outputs = len(trt_module_node.meta["val"])
# Insert a call_function node to perform inference on TRT engine
with gm.graph.inserting_before(trt_module_node):
engine_name = f"{name}_engine"
@@ -390,19 +381,9 @@ def inline_trt_modules(
torch.ops.tensorrt.execute_engine.default,
(trt_module_node.args, engine_node),
)
trt_node.meta["val"] = []
# Set trt_node.meta["val"] from the metadata recorded on trt_module_node
assert num_outputs > 0
# Generate meta data for TRT node (a FakeTensor with corresponding output shape)
for idx in range(num_outputs):
trt_node.meta["val"].append(
cast(
FakeTensor,
torch.empty_strided(
tuple(outputs_map[trt_module_node.name][idx]),
tuple([1] * len(outputs_map[trt_module_node.name][idx])),
),
)
)
trt_node.meta["val"] = trt_module_node.meta["val"]

# meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
# Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
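With run_shape_analysis removed, inline_trt_modules simply copies meta["val"] (a list of FakeTensors) from the submodule node onto the inlined execute_engine node. An illustrative sketch of how that metadata can be read back, assuming gm is the transformed graph module and the Torch-TensorRT runtime is loaded so the tensorrt ops are registered (summarize_trt_engine_outputs is a hypothetical helper, not from this PR):

    import torch
    import torch_tensorrt  # registers torch.ops.tensorrt

    def summarize_trt_engine_outputs(gm: torch.fx.GraphModule) -> None:
        # Each inlined TRT engine node now carries FakeTensor output metadata
        # copied from its submodule node, instead of shapes recovered by a
        # dummy forward pass.
        for node in gm.graph.nodes:
            if node.target == torch.ops.tensorrt.execute_engine.default:
                for idx, fake in enumerate(node.meta["val"]):
                    print(f"{node.name} output {idx}: shape={tuple(fake.shape)}, dtype={fake.dtype}")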
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -74,8 +74,6 @@ def construct_refit_mapping(

output_dtypes = infer_module_output_dtypes(
module,
inputs,
settings.device,
truncate_double=settings.truncate_double,
)

69 changes: 9 additions & 60 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -5,8 +5,6 @@

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
@@ -17,58 +15,22 @@
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
from torch_tensorrt.dynamo.utils import get_output_dtypes

logger = logging.getLogger(__name__)


def infer_module_output_dtypes(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
kwarg_inputs: Optional[dict[str, Any]] = None,
truncate_double: bool = False,
) -> List[dtype]:
"""
This function performs model inference to determine the output dtypes
and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
This function gets the output dtypes from node.meta['val'], which was set during the
dynamo compile_module step, and truncates them accordingly.
"""
# TODO: We can also determine output dtypes from the module.graph based on node metadata.
# However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
# so we stick to the model inference approach currently.
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes
outputs = [node for node in module.graph.nodes if node.op == "output"]
outputs = outputs[0].args
return get_output_dtypes(outputs, truncate_double)


def interpret_module_to_result(
@@ -91,22 +53,9 @@ def interpret_module_to_result(
Returns:
TRTInterpreterResult
"""
if arg_inputs is not None:
output_dtypes = infer_module_output_dtypes(
module,
arg_inputs,
settings.device,
kwarg_inputs=kwarg_inputs,
truncate_double=settings.truncate_double,
)
else:
# args and kwargs are combined and flattened to one list
output_dtypes = infer_module_output_dtypes(
module,
inputs,
settings.device,
truncate_double=settings.truncate_double,
)
output_dtypes = infer_module_output_dtypes(
module, truncate_double=settings.truncate_double
)

interpreter = TRTInterpreter(
module,
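infer_module_output_dtypes now reads dtypes straight from the graph's output-node metadata instead of running the module. get_output_dtypes lives in torch_tensorrt.dynamo.utils and is not shown in this view; a hedged sketch of its likely behavior, assuming every returned node carries a FakeTensor in meta["val"] (the body is an assumption, only the name and call site come from this diff):

    from typing import Any, List

    import torch
    from torch_tensorrt._enums import dtype

    def get_output_dtypes_sketch(outputs: Any, truncate_double: bool = False) -> List[dtype]:
        output_dtypes = []
        for output in outputs:
            if isinstance(output, (tuple, list)):
                # The output node's args may wrap the returned values in a tuple.
                output_dtypes.extend(get_output_dtypes_sketch(output, truncate_double))
                continue
            meta_val = output.meta["val"]
            # Float64 outputs can arise from ops such as aten.sum; truncate to
            # float32 when requested, mirroring the removed inference-based path.
            if truncate_double and meta_val.dtype == torch.float64:
                output_dtypes.append(dtype.float32)
            else:
                output_dtypes.append(dtype._from(meta_val.dtype))
        return output_dtypes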
32 changes: 14 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -1,11 +1,11 @@
import logging
from typing import Callable, Tuple

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
from torch_tensorrt.dynamo.utils import get_metadata, set_metadata

logger = logging.getLogger(__name__)

@@ -14,33 +14,29 @@ def lower_linear(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
orig, replacement = linear_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return gm


def linear_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for linear"""
orig_op = torch.ops.aten.addmm.default
replacement_op = torch.ops.aten.linear.default

# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
W_T = torch.ops.aten.permute.default(weight, [1, 0])
out = torch.ops.aten.addmm.default(bias, input, W_T)
out = orig_op(bias, input, W_T)
return out

# Replacement graph
def replacement(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return torch.ops.aten.linear.default(input, weight, bias)
return replacement_op(input, weight, bias)

metadata = get_metadata(gm, orig_op)
replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)

if len(replaced_nodes) > 0:
gm = clean_up_graph_after_modifications(gm)
set_metadata(gm, replacement_op, metadata)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return orig, replacement
return gm
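The rewritten pass follows a preserve-and-restore pattern: capture the metadata of nodes matching the original op before the pattern rewrite, then reattach it to the replacement nodes, since downstream stages now rely on node.meta rather than a dummy forward pass. A hedged sketch of get_metadata/set_metadata from torch_tensorrt.dynamo.utils (names from this diff, bodies assumed):

    from typing import Any, Dict, List

    import torch
    from torch._ops import OpOverload

    def get_metadata_sketch(
        gm: torch.fx.GraphModule, target_op: OpOverload
    ) -> List[Dict[str, Any]]:
        # Collect the meta dicts of all nodes that the rewrite will replace.
        return [node.meta for node in gm.graph.nodes if node.target == target_op]

    def set_metadata_sketch(
        gm: torch.fx.GraphModule, target_op: OpOverload, metadata: List[Dict[str, Any]]
    ) -> None:
        # Reattach the saved metadata to the replacement nodes; this assumes the
        # rewrite preserves graph order so nodes and saved metadata line up.
        replaced = [node for node in gm.graph.nodes if node.target == target_op]
        for node, meta in zip(replaced, metadata):
            node.meta = meta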