Skip to content

feat: Add preliminary support for freezing tensors in Dynamo #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 119 additions & 30 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
from functools import partial
from typing import Any, Callable, Sequence
import unittest
from typing import Any, Callable, Dict, Optional, Sequence

import torch
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
import torch.utils._pytree as pytree
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import _aot_export_function
from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
from torch._ops import OpOverload
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
Expand All @@ -33,31 +37,15 @@ def torch_tensorrt_backend(

DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
return compiled_mod
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
) -> torch.nn.Module:
settings = parse_dynamo_kwargs(kwargs)

custom_backend = partial(
_pretraced_backend,
settings=settings,
)

# Perform Pre-AOT Lowering for Module-Level Replacement
gm = pre_aot_substitutions(gm)

# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
gm,
sample_inputs,
fw_compiler=make_boxed_compiler(custom_backend),
decompositions=get_decompositions(settings.enable_experimental_decompositions),
)
return _pretraced_backend(gm, sample_inputs, settings)


def _pretraced_backend(
Expand All @@ -75,22 +63,44 @@ def _pretraced_backend(
Compiled FX GraphModule
"""
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))

# Perform Pre-AOT Lowering for Module-Level Replacement
gm = pre_aot_substitutions(gm)

fake_mode = detect_fake_mode(sample_inputs)

# Place backend tracing within FakeTensor context allowing nonfake Tensors
with unittest.mock.patch.object(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
# Invoke AOTAutograd to translate operators to aten
graph_module = aot_export_for_compile(
gm,
sample_inputs,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
)

trt_compiled = compile_module(
gm,
sample_inputs,
settings=settings,
)
return trt_compiled
except AssertionError:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

constant_fold(graph_module)

trt_compiled = compile_module(
graph_module,
sample_inputs,
settings=settings,
)
return trt_compiled
except (AssertionError, RuntimeError):
if not settings.pass_through_build_failures:
logger.warning(
"TRT conversion failed on the subgraph. See trace above. "
+ "Returning GraphModule forward instead.",
exc_info=True,
)
return gm.forward
return gm
else:
logger.critical(
"Halting compilation on build failure since "
Expand All @@ -100,3 +110,82 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


@torch.utils._python_dispatch._disable_current_modes() # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> Any:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

Folds constants in the graph module, not skipping constructors

Modifies the graph in-place and replaces node with constants
"""
cf = ConstantFolder(gm, skip_constructors=False)
cf.run()

for node, constant in cf.node_replacements.items():
replace_node_with_constant(gm, node, constant)

erased_params = []
for node in gm.graph.nodes:
if node.op == "get_attr" and len(node.users) == 0:
delattr(gm, node.target)
erased_params.append(node)

for node in erased_params:
gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()


def aot_export_for_compile(
func: torch.fx.GraphModule,
args: Sequence[torch.Tensor],
*,
decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
) -> torch.fx.GraphModule:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158

Removed check for input aliasing in resultant subgraph - TRT is functional-only

Exports the function to ATen for torch compile
"""
# Trace function with input arguments and decompositions
with torch.no_grad():
fx_g, metadata, in_spec, out_spec = _aot_export_function(
func,
args,
decompositions=decompositions,
)

# No input mutations
if (
len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
!= 0
):
raise RuntimeError(
f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
)
# No pytrees
if type(in_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
if type(out_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)

return fx_g
31 changes: 29 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
from torch_tensorrt.fx.observer import Observer
Expand Down Expand Up @@ -169,7 +170,7 @@ def run(

cache = None
if timing_cache:
cache_file = numpy.array(timing_cache)
cache_file = np.array(timing_cache)
cache = builder_config.create_timing_cache(cache_file.tobytes())
else:
cache = builder_config.create_timing_cache(b"")
Expand Down Expand Up @@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
assert self._cur_node_name is not None
return converter(self.network, target, args, kwargs, self._cur_node_name)

def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
with _disable_current_modes():
from torch_tensorrt.fx.converters import to_numpy

frozen_attr = self.fetch_attr(target)

if isinstance(frozen_attr, torch.nn.Parameter):
constant_tensor = frozen_attr.data
else:
constant_tensor = frozen_attr

network_constant = to_numpy(constant_tensor)

return network_constant

def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
assert isinstance(target, str)
converter = CONVERTERS.get(self._cur_node)
Expand All @@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
else:
outputs = (args[0],)

for output_idx in range(len(outputs)):
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

output = outputs[output_idx]

if not isinstance(output, trt.tensorrt.ITensor):
new_output = get_trt_tensor(self.network, output, target)
outputs = (
outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
)

if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
raise RuntimeError("TensorRT requires all outputs to be Tensor!")

Expand Down
Loading