
feat: Add Selective ATen decompositions #2173


Merged: 5 commits, Aug 17, 2023
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -11,3 +11,4 @@
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -4,6 +4,7 @@
import torch
from torch_tensorrt.dynamo._defaults import (
DEBUG,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
@@ -19,6 +20,27 @@

@dataclass
class CompilationSettings:
"""Compilation settings for Torch-TensorRT Dynamo Paths

Args:
precision (torch.dtype): Model Layer precision
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool): Provide version forward-compatibility for engine plan files
optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
"""

precision: torch.dtype = PRECISION
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
@@ -31,3 +53,4 @@ class CompilationSettings:
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
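
For context, the new field slots into the existing dataclass defaults, so settings can be built with only the flag overridden. A minimal sketch using only names visible in this diff:

```python
from torch_tensorrt.dynamo import CompilationSettings

# Every field of CompilationSettings has a default drawn from _defaults.py,
# so opting in to the experimental decompositions is a one-field override.
settings = CompilationSettings(enable_experimental_decompositions=True)
assert settings.enable_experimental_decompositions
```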
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -48,7 +48,7 @@ def aot_torch_tensorrt_aten_backend(
gm,
sample_inputs,
fw_compiler=make_boxed_compiler(custom_backend),
decompositions=get_decompositions(),
decompositions=get_decompositions(settings.enable_experimental_decompositions),
)


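The backend now threads the flag into `get_decompositions`. That helper is not part of this diff; the sketch below is a plausible reconstruction, assuming it picks between the full core ATen set and the curated subset defined in `_decomposition_groups.py` (shown further down):

```python
from typing import Any, Callable, Dict

from torch._decomp import core_aten_decompositions
from torch._ops import OpOverload
from torch_tensorrt.dynamo.lowering._decomposition_groups import (
    ENABLED_TORCH_DECOMPOSITIONS,
    TORCH_TRT_DECOMPOSITIONS,
    torch_disabled_decompositions,
)


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
    # Assumed behavior: with the flag set, use every core ATen decomposition
    # except those explicitly disabled; otherwise use only the curated
    # subset. Torch-TRT's own decompositions apply in both cases.
    if enable_experimental_decompositions:
        core = {
            op: fn
            for op, fn in core_aten_decompositions().items()
            if op not in torch_disabled_decompositions
        }
        return {**core, **TORCH_TRT_DECOMPOSITIONS}
    return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}
```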
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/compile.py
@@ -14,6 +14,7 @@
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._defaults import (
DEBUG,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
@@ -63,6 +64,7 @@ def compile(
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
@@ -72,9 +74,10 @@

logger.warning(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
"following arguments are supported: "
"{enabled_precisions, debug, workspace_size, min_block_size, "
"torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
"enable_experimental_decompositions}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -115,6 +118,7 @@
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
}

settings = CompilationSettings(**compilation_options)
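End to end, a user reaches the flag through `torch_tensorrt.dynamo.compile`. A hedged usage sketch, with a placeholder model and inputs, assuming a CUDA device and a TensorRT install:

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 4), torch.nn.GELU()).cuda().eval()
inputs = [torch.randn(8, 16).cuda()]

# Opt in to the full core ATen decomposition set rather than the
# curated subset; the flag defaults to False per _defaults.py.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enable_experimental_decompositions=True,
)
```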
200 changes: 200 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -0,0 +1,200 @@
from typing import Any, Callable, Dict, Set

import torch
from torch._decomp import core_aten_decompositions
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload

aten = torch.ops.aten

_core_aten_decompositions: Dict[
OpOverload, Callable[[Any], Any]
] = core_aten_decompositions()
torch_enabled_decompositions: Set[OpOverload] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
aten.addcmul_,
aten.addr,
aten.aminmax,
aten.arange.default,
aten.arange.start,
aten.avg_pool2d_backward,
aten.binary_cross_entropy,
aten.binary_cross_entropy_backward,
aten.binary_cross_entropy_with_logits,
aten.celu,
aten.col2im,
aten.count_nonzero,
aten.cudnn_batch_norm,
aten.cudnn_batch_norm_backward,
aten.deg2rad,
aten.detach,
aten.diag_embed,
aten.diagonal_backward,
aten.dot,
aten.elu,
aten.elu_backward,
aten._embedding_bag,
aten.embedding_dense_backward,
aten._euclidean_dist.default,
aten.expand_as,
aten.eye,
aten.fill,
aten.frac,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu,
aten.gelu_backward,
aten.glu_backward,
aten.grid_sampler_2d,
aten.hardshrink,
aten.hardshrink_backward,
aten.hardsigmoid,
aten.hardsigmoid_backward,
aten.hardswish,
aten.hardswish_,
aten.hardswish_backward,
aten.hardtanh,
aten.hardtanh_,
aten.hardtanh_backward,
aten.heaviside,
aten.huber_loss,
aten.huber_loss_backward,
aten.im2col,
aten.index_add,
aten.index_add_,
aten.index_copy,
aten.index_copy_,
aten.index_fill,
aten.index_fill_,
aten.index_select,
aten.isneginf,
aten.isposinf,
aten.l1_loss,
aten.leaky_relu,
aten.leaky_relu_,
aten.leaky_relu_backward,
aten.lerp,
aten.linspace,
aten.logaddexp,
aten.logaddexp2,
aten.logit,
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
aten.masked_fill,
aten.masked_fill_,
aten.max_pool2d_with_indices_backward,
aten.mish,
aten.mse_loss,
aten.mse_loss_backward,
aten.mv,
aten.mvlgamma,
aten.nansum,
aten.nan_to_num,
aten.narrow,
# TODO: Disable the below operators once freezing is done
aten.native_batch_norm,
aten.native_batch_norm_backward,
aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional,
aten._native_batch_norm_legit_no_training,
aten.native_dropout_backward,
aten.native_group_norm,
aten.native_group_norm_backward,
aten.native_layer_norm,
aten.native_layer_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,
aten.new_zeros,
aten.nll_loss_backward,
aten.nll_loss_forward,
aten.norm,
aten.ones,
aten.ones_like,
aten._prelu_kernel,
aten._prelu_kernel_backward,
aten._reshape_alias,
aten.rad2deg,
aten.renorm,
aten.renorm_,
aten.rot90,
aten.rsub.Scalar,
aten.rsub.Tensor,
aten.select_backward,
aten.select_scatter,
aten.sgn,
aten.sigmoid_backward,
aten.silu,
aten.silu_,
aten.silu_backward,
aten.sinc,
aten.slice_backward,
aten.smooth_l1_loss,
aten.smooth_l1_loss_backward,
aten.soft_margin_loss,
aten.soft_margin_loss_backward,
aten._softmax,
aten._softmax_backward_data,
aten.softplus,
aten.softplus_backward,
aten.softshrink,
aten.softshrink_backward,
aten.special_entr,
aten.special_log_ndtr,
aten.special_xlog1py,
aten.stack,
aten.t,
aten.tanh_backward,
aten.threshold,
aten.threshold_backward,
aten.trace,
aten.transpose.int,
aten.tril.default,
aten.triu.default,
aten.unfold,
aten.unfold_backward,
aten.unfold_copy,
aten.upsample_bilinear2d,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.xlogy,
aten.zero,
aten.zero_,
aten.zeros,
aten.zeros_like,
# Non-default convenience decompositions
aten.clamp_min,
aten.clamp_max,
aten.linalg_vector_norm,
aten.full,
aten.repeat,
}
torch_disabled_decompositions: Set[OpOverload] = set()


ENABLED_TORCH_DECOMPOSITIONS: Dict[
OpOverload, Callable[[Any], Any]
] = get_torch_decompositions(torch_enabled_decompositions)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)

if overlap:
raise AssertionError(
f"Detected {overlap} registered in both torch_enabled_decompositions "
"and torch_disabled_decompositions. Ensure all operator(s) are in "
"at most one of the two sets."
)


check_decomp_set_invariants()
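
Since the module runs `check_decomp_set_invariants()` at import time, any overlap between the two sets fails fast. A small hypothetical demonstration of the guard (not part of this PR):

```python
from torch_tensorrt.dynamo.lowering._decomposition_groups import (
    aten,
    check_decomp_set_invariants,
    torch_disabled_decompositions,
)

# Hypothetical misuse: register an already-enabled op in the disabled set.
torch_disabled_decompositions.add(aten.gelu)
try:
    check_decomp_set_invariants()  # raises: aten.gelu is now in both sets
except AssertionError as err:
    print(err)
finally:
    torch_disabled_decompositions.discard(aten.gelu)  # restore the invariant
```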