switch from fx.symbolic_trace to dynamo_trace for converter test part-1 #3261

Status: Open. Wants to merge 42 commits into base: main (changes shown from 34 commits).

Commits (42)
458a4d1  skip run_shape_analysis (lanluo-nvidia, Oct 6, 2024)
2f408f9  test (lanluo-nvidia, Oct 6, 2024)
1c5e86c  test (lanluo-nvidia, Oct 6, 2024)
ba487dc  test (lanluo-nvidia, Oct 6, 2024)
99d2274  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 6, 2024)
2b43480  test (lanluo-nvidia, Oct 6, 2024)
b4e02e1  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 11, 2024)
3d94f8b  test (lanluo-nvidia, Oct 13, 2024)
28ba6cc  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 15, 2024)
b89cbe0  resolve comments (lanluo-nvidia, Oct 15, 2024)
2843d37  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 16, 2024)
3eb48d7  test (lanluo-nvidia, Oct 16, 2024)
50eb0d8  replace dummy inference (lanluo-nvidia, Oct 20, 2024)
95ed602  test (lanluo-nvidia, Oct 20, 2024)
120f30d  test (lanluo-nvidia, Oct 21, 2024)
424cbf7  add run_test_with_dynamic_shape change (lanluo-nvidia, Oct 21, 2024)
2fc9cef  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 21, 2024)
ef54cfc  split the PR, add dummy inference for converter test (lanluo-nvidia, Oct 21, 2024)
14f5d61  test (lanluo-nvidia, Oct 22, 2024)
7563959  test (lanluo-nvidia, Oct 22, 2024)
77355f0  test (lanluo-nvidia, Oct 22, 2024)
891e963  enable converter non dynamic shape tests to use dynamo tracer (lanluo-nvidia, Oct 22, 2024)
13361fd  add linear lowering meta val (lanluo-nvidia, Oct 22, 2024)
f0a9fef  add linear_lowering change (lanluo-nvidia, Oct 23, 2024)
cff64a4  test (lanluo-nvidia, Oct 23, 2024)
814262f  Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 23, 2024)
933abac  test (lanluo-nvidia, Oct 23, 2024)
8417684  resolve comments (lanluo-nvidia, Oct 25, 2024)
8676f88  test (lanluo-nvidia, Oct 25, 2024)
d8e52bf  test (lanluo-nvidia, Oct 27, 2024)
4d46235  Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 27, 2024)
8b3842a  test (lanluo-nvidia, Oct 27, 2024)
7ddf56f  test (lanluo-nvidia, Oct 27, 2024)
39e0a49  test (lanluo-nvidia, Oct 27, 2024)
076f47a  resolve comments (lanluo-nvidia, Oct 29, 2024)
8250179  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 29, 2024)
96e93e4  resolve comments (lanluo-nvidia, Oct 29, 2024)
7a9659f  Merge branch 'lluo/save_remove_inputs' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 29, 2024)
cb656bb  Merge branch 'main' into lluo/switch_to_dynamo_trace (lanluo-nvidia, Oct 29, 2024)
c023714  resolve comments (lanluo-nvidia, Oct 30, 2024)
594ca28  resolve comments (lanluo-nvidia, Oct 30, 2024)
56d034b  resolve comments (lanluo-nvidia, Oct 31, 2024)
Files changed
23 changes: 18 additions & 5 deletions py/torch_tensorrt/_compile.py
@@ -502,19 +502,24 @@ def save(
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
torch.jit.save(module, file_path)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
"Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
torch.export.save(module, file_path)
elif module_type == _ModuleType.fx:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
)

# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(
@@ -525,11 +530,19 @@ def save(
if not retrace:
from torch_tensorrt.dynamo._exporter import export

exp_program = export(module, arg_inputs, kwarg_inputs)
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
)
exp_program = export(module)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
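Taken together, this hunk makes arg_inputs optional on every path except retrace=True. A minimal usage sketch of the new behavior (the toy model, shapes, and file names are illustrative assumptions, not part of this PR):

import torch
import torch_tensorrt


class MyModel(torch.nn.Module):  # assumption: stand-in user model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# retrace=False (default): the exporter reuses metadata recorded during
# compilation, so passing arg_inputs only triggers the new warning.
torch_tensorrt.save(trt_gm, "trt.ep")

# retrace=True: a fresh torch.export.export() call is made, so valid
# tensors must be supplied or save() now raises ValueError.
torch_tensorrt.save(trt_gm, "trt_retraced.ep", arg_inputs=inputs, retrace=True)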
25 changes: 24 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,7 @@
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
prepare_inputs,
set_log_level,
@@ -295,7 +296,6 @@ def compile(

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(enable_experimental_decompositions)
@@ -426,6 +426,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
if not settings.use_fast_partitioner:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

submodule_node_dict = {}
for node in partitioned_module.graph.nodes:
if "_run_on_acc" not in node.name:
continue
submodule_node_dict[node.name] = node

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
@@ -445,6 +451,23 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
)
continue

if name not in submodule_node_dict:
raise ValueError(
f"node_name: {name} does not exist in the submodule node dictionary"
)

# set the submodule metadata back to the parent trt_module_node
metadata_list = get_output_metadata(submodule)
assert len(metadata_list) > 0
if "val" not in submodule_node_dict[name].meta:
meta_val_list = [
metadata["val"] for metadata in metadata_list if "val" in metadata
]
submodule_node_dict[name].meta["val"] = meta_val_list
logger.debug(
f"Update submodule output metadata back to the parent trt_module_node: {name}"
)

subgraph_data = PerSubgraphData()
subgraph_data.subgraph_name = name
subgraph_data.subgraph_op_count = len(
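get_output_metadata is imported from torch_tensorrt.dynamo.utils and ships with this PR; a plausible sketch of its behavior, assuming it collects the .meta dict of every node feeding the submodule's output node (the _sketch suffix marks this as illustrative, not the shipped helper):

from typing import Any, Dict, List

import torch.fx


def get_output_metadata_sketch(gm: torch.fx.GraphModule) -> List[Dict[str, Any]]:
    # Sketch only: gather the metadata of the nodes that produce the
    # submodule's outputs, so compile_module can lift each "val"
    # (a FakeTensor) onto the parent _run_on_acc node's meta["val"].
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    out_args = output_node.args[0]
    if isinstance(out_args, torch.fx.Node):
        out_args = (out_args,)
    return [a.meta for a in out_args if isinstance(a, torch.fx.Node)]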
41 changes: 11 additions & 30 deletions py/torch_tensorrt/dynamo/_exporter.py
@@ -1,6 +1,6 @@
import copy
import operator
from typing import Any, Dict, Optional, Sequence, Tuple, cast
from typing import Any, Dict, Sequence, Tuple, cast

import torch
from torch._guards import detect_fake_mode
@@ -16,31 +16,24 @@
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo import partitioning


def export(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
kwarg_inputs: Optional[dict[str, Any]] = None,
) -> ExportedProgram:
"""Export the result of TensorRT compilation into the desired output format.

Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
"""
if kwarg_inputs is None:
kwarg_inputs = {}
patched_module = transform(gm, inputs, kwarg_inputs)
patched_module = transform(gm)
exp_program = create_trt_exp_program(patched_module)
return exp_program


def transform(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
kwarg_inputs: Optional[dict[str, Any]] = None,
) -> torch.fx.GraphModule:
"""
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -55,14 +48,10 @@ def transform(
"""
# Make a copy of the graph since this function transforms the input graph and changes its attributes.
# This transformed graph is meant to be consumed by `create_trt_exp_program`
if kwarg_inputs is None:
kwarg_inputs = {}
gm = copy.deepcopy(gm)
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwarg_inputs)

# Inline TensorRT submodules
inline_trt_modules(gm, outputs_map)
inline_trt_modules(gm)

# Inline pytorch submodules
inline_torch_modules(gm)
@@ -361,9 +350,7 @@ def create_trt_exp_program(
return trt_exp_program


def inline_trt_modules(
gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]]
) -> torch.fx.GraphModule:
def inline_trt_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Replace TRT submodules with trt engine nodes.
"""
@@ -379,7 +366,11 @@ def inline_trt_modules(
trt_module_node = trt_module_node[0]
assert trt_module_node.args

num_outputs = len(outputs_map[trt_module_node.name])
if "val" not in trt_module_node.meta:
raise ValueError(
f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step."
)
num_outputs = len(trt_module_node.meta["val"])
# Insert a call_function node to perform inference on TRT engine
with gm.graph.inserting_before(trt_module_node):
engine_name = f"{name}_engine"
@@ -390,19 +381,9 @@
torch.ops.tensorrt.execute_engine.default,
(trt_module_node.args, engine_node),
)
trt_node.meta["val"] = []
# Copy the output metadata from trt_module_node onto trt_node
assert num_outputs > 0
# Generate meta data for TRT node (a FakeTensor with corresponding output shape)
for idx in range(num_outputs):
trt_node.meta["val"].append(
cast(
FakeTensor,
torch.empty_strided(
tuple(outputs_map[trt_module_node.name][idx]),
tuple([1] * len(outputs_map[trt_module_node.name][idx])),
),
)
)
trt_node.meta["val"] = trt_module_node.meta["val"]

# meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
# Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
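With run_shape_analysis removed, both the output count and the per-output shapes now come from meta["val"], which holds FakeTensors. A small illustration of what that metadata looks like (the shapes are assumptions):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A FakeTensor carries shape/dtype/device but no storage: exactly what
# create_trt_exp_program needs to build output specs for the TRT node.
with FakeTensorMode():
    fake_outputs = [torch.empty(8, 1000, dtype=torch.float32)]
print(len(fake_outputs), fake_outputs[0].shape)  # 1 torch.Size([8, 1000])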
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -74,8 +74,6 @@ def construct_refit_mapping(

output_dtypes = infer_module_output_dtypes(
module,
inputs,
settings.device,
truncate_double=settings.truncate_double,
)

3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -115,6 +115,9 @@ def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]
args = list(signature(mod.forward).parameters.keys())
dynamic_shapes = {}
for input, input_name in zip(inputs, args[: len(inputs)]):
# If input.name is set (not None and not an empty string) and differs from the parameter name, use it as the key
if input.name is not None and len(input.name) > 0 and input.name != input_name:
input_name = input.name
dynamic_shapes[input_name] = get_dynamic_shapes(input)
return dynamic_shapes

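The effect is easiest to see with an explicitly named Input whose name differs from the forward() parameter; a short sketch (the toy module and shapes are assumptions):

import torch
import torch_tensorrt
from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args


class Toy(torch.nn.Module):  # assumption: minimal single-input module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


# When Input.name is set, it overrides the positional parameter name "x",
# so the dynamic-shape spec is keyed under the user-chosen name.
inp = torch_tensorrt.Input(
    min_shape=(1, 3), opt_shape=(4, 3), max_shape=(8, 3),
    dtype=torch.float32, name="my_tensor",
)
print(get_dynamic_shapes_args(Toy(), [inp]))  # keyed by "my_tensor"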
69 changes: 9 additions & 60 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -5,8 +5,6 @@

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
@@ -17,58 +15,22 @@
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
from torch_tensorrt.dynamo.utils import get_output_dtypes

logger = logging.getLogger(__name__)


def infer_module_output_dtypes(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
kwarg_inputs: Optional[dict[str, Any]] = None,
truncate_double: bool = False,
) -> List[dtype]:
"""
This function performs model inference to determine the output dtypes
and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
This function gets the output dtypes from node.meta['val'], which was set during the dynamo compile_module step,
and truncates them accordingly.
"""
# TODO: We can also determine output dtypes from the module.graph based on node metadata.
# However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
# so we stick to the model inference approach currently.
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes
outputs = [node for node in module.graph.nodes if node.op == "output"]
outputs = outputs[0].args
return get_output_dtypes(outputs, truncate_double)


def interpret_module_to_result(
@@ -91,22 +53,9 @@ def interpret_module_to_result(
Returns:
TRTInterpreterResult
"""
if arg_inputs is not None:
output_dtypes = infer_module_output_dtypes(
module,
arg_inputs,
settings.device,
kwarg_inputs=kwarg_inputs,
truncate_double=settings.truncate_double,
)
else:
# args and kwargs are combined and flattened to one list
output_dtypes = infer_module_output_dtypes(
module,
inputs,
settings.device,
truncate_double=settings.truncate_double,
)
output_dtypes = infer_module_output_dtypes(
module, truncate_double=settings.truncate_double
)

interpreter = TRTInterpreter(
module,
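The new path never executes the module; it reads dtypes that dynamo tracing already recorded on the graph. A plausible sketch of the get_output_dtypes helper (the real one lives in torch_tensorrt.dynamo.utils; treat this signature as an assumption):

from typing import Any, List

import torch
from torch_tensorrt._enums import dtype


def get_output_dtypes_sketch(outputs: Any, truncate_double: bool = False) -> List[dtype]:
    # Sketch only: walk the (possibly nested) output args and read each
    # node's FakeTensor metadata instead of running real inference.
    output_dtypes: List[dtype] = []
    for out in outputs:
        if isinstance(out, (tuple, list)):
            output_dtypes.extend(get_output_dtypes_sketch(out, truncate_double))
        elif isinstance(out, torch.fx.Node):
            val = out.meta["val"]  # set during the dynamo compile_module step
            if truncate_double and val.dtype == torch.float64:
                output_dtypes.append(dtype.float32)
            else:
                output_dtypes.append(dtype._from(val.dtype))
    return output_dtypes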
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -3,7 +3,8 @@
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._decomp import _decomp_table_to_post_autograd_aten, register_decomposition
from torch._decomp import register_decomposition
from torch._export.utils import _decomp_table_to_post_autograd_aten
from torch._ops import OpOverload
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
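The helper's import location differs across PyTorch versions; a compatibility sketch for code that must run on both layouts (the version boundary is an assumption, not stated by this PR):

try:
    # Newer PyTorch layout, as used by this PR.
    from torch._export.utils import _decomp_table_to_post_autograd_aten
except ImportError:
    # Older releases exposed the same helper from torch._decomp.
    from torch._decomp import _decomp_table_to_post_autograd_aten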
32 changes: 14 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -1,11 +1,11 @@
import logging
from typing import Callable, Tuple

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
from torch_tensorrt.dynamo.utils import get_metadata, set_metadata

logger = logging.getLogger(__name__)

@@ -14,33 +14,29 @@ def lower_linear(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
orig, replacement = linear_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return gm


def linear_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for linear"""
orig_op = torch.ops.aten.addmm.default
replacement_op = torch.ops.aten.linear.default

# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
W_T = torch.ops.aten.permute.default(weight, [1, 0])
out = torch.ops.aten.addmm.default(bias, input, W_T)
out = orig_op(bias, input, W_T)
return out

# Replacement graph
def replacement(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
return torch.ops.aten.linear.default(input, weight, bias)
return replacement_op(input, weight, bias)

metadata = get_metadata(gm, orig_op)
replaced_nodes = torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)

if len(replaced_nodes) > 0:
gm = clean_up_graph_after_modifications(gm)
set_metadata(gm, replacement_op, metadata)
logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return orig, replacement
return gm
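The decomposed pattern this pass matches can be reproduced directly; a toy example, assuming default export decompositions (which split aten.linear into permute plus addmm):

import torch


class LinearModel(torch.nn.Module):  # assumption: minimal linear model
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# After decomposition the graph holds permute + addmm; lower_linear fuses
# them back into a single aten.linear node and re-attaches the captured
# addmm metadata (including meta["val"]) so the TRT metadata propagation
# introduced elsewhere in this PR keeps working.
ep = torch.export.export(LinearModel(), (torch.randn(2, 16),))
gm = ep.run_decompositions().module()
print(gm.graph)  # expect aten.permute.default and aten.addmm.default nodes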