Skip to content

Commit 34a190e

Browse files
committed
fix: Add enabled/disabled sets for decompositions
- Add sets to selectively enable or disable decompositions in Torch - Add new runtime argument `enable_experimental_decompositions` to enable all core aten decompositions, or a pre-selected subset thereof - Improve documentation of compilation settings overall
1 parent 0527edd commit 34a190e

File tree

5 files changed

+211
-6
lines changed

5 files changed

+211
-6
lines changed

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
OPTIMIZATION_LEVEL = None
1111
USE_PYTHON_RUNTIME = None
1212
TRUNCATE_LONG_AND_DOUBLE = False
13+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,33 @@
1212
OPTIMIZATION_LEVEL,
1313
USE_PYTHON_RUNTIME,
1414
TRUNCATE_LONG_AND_DOUBLE,
15+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1516
)
1617

1718

1819
@dataclass
1920
class CompilationSettings:
21+
"""Compilation settings for Torch-TensorRT Dynamo Paths
22+
23+
Args:
24+
precision (torch.dtype): Model Layer precision
25+
debug (bool): Whether to print out verbose debugging information
26+
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
27+
min_block_size (int): Minimum number of operators per TRT-Engine Block
28+
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
29+
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
30+
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
31+
version_compatible (bool): Provide version forward-compatibility for engine plan files
32+
optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
33+
searching for more optimization options. TRT defaults to 3
34+
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
35+
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
36+
argument as None
37+
truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
38+
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
39+
or only a selected subset of them
40+
"""
41+
2042
precision: torch.dtype = PRECISION
2143
debug: bool = DEBUG
2244
workspace_size: int = WORKSPACE_SIZE
@@ -28,3 +50,4 @@ class CompilationSettings:
2850
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
2951
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
3052
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
53+
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def aot_torch_tensorrt_aten_backend(
5555
gm,
5656
sample_inputs,
5757
fw_compiler=make_boxed_compiler(custom_backend),
58-
decompositions=get_decompositions(),
58+
decompositions=get_decompositions(settings.enable_experimental_decompositions),
5959
)
6060

6161

py/torch_tensorrt/dynamo/compile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
OPTIMIZATION_LEVEL,
3232
USE_PYTHON_RUNTIME,
3333
TRUNCATE_LONG_AND_DOUBLE,
34+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
3435
)
3536

3637

@@ -64,6 +65,7 @@ def compile(
6465
version_compatible=VERSION_COMPATIBLE,
6566
optimization_level=OPTIMIZATION_LEVEL,
6667
use_python_runtime=USE_PYTHON_RUNTIME,
68+
enable_experimental_decompositions=ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
6769
**kwargs,
6870
):
6971
if debug:
@@ -73,7 +75,7 @@ def compile(
7375
"The Dynamo backend is an experimental feature, for which only the "
7476
+ "following arguments are supported: "
7577
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
76-
+ "torch_executed_ops, pass_through_build_failures}"
78+
+ "torch_executed_ops, pass_through_build_failures, enable_experimental_decompositions}"
7779
)
7880

7981
if not isinstance(inputs, collections.abc.Sequence):
@@ -111,6 +113,7 @@ def compile(
111113
"optimization_level": optimization_level,
112114
"use_python_runtime": use_python_runtime,
113115
"truncate_long_and_double": truncate_long_and_double,
116+
"enable_experimental_decompositions": enable_experimental_decompositions,
114117
}
115118

116119
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 182 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,183 @@
1+
from typing import Callable, Dict
12
import torch
23
from torch._decomp import register_decomposition, core_aten_decompositions
34

5+
aten = torch.ops.aten
46

5-
DECOMPOSITIONS = {**core_aten_decompositions()}
7+
_core_aten_decompositions = core_aten_decompositions()
8+
enabled_decompositions = {
9+
aten._adaptive_avg_pool2d_backward,
10+
aten.addcdiv,
11+
aten.addcdiv_,
12+
aten.addcmul,
13+
aten.addcmul_,
14+
aten.addr,
15+
aten.aminmax,
16+
aten.arange.default,
17+
aten.arange.start,
18+
aten.avg_pool2d_backward,
19+
aten.binary_cross_entropy,
20+
aten.binary_cross_entropy_backward,
21+
aten.binary_cross_entropy_with_logits,
22+
aten.celu,
23+
aten.col2im,
24+
aten.count_nonzero,
25+
aten.cudnn_batch_norm,
26+
aten.cudnn_batch_norm_backward,
27+
aten.deg2rad,
28+
aten.detach,
29+
aten.diag_embed,
30+
aten.diagonal_backward,
31+
aten.dot,
32+
aten.elu,
33+
aten.elu_backward,
34+
aten._embedding_bag,
35+
aten.embedding_dense_backward,
36+
aten._euclidean_dist.default,
37+
aten.expand_as,
38+
aten.eye,
39+
aten.fill,
40+
aten.frac,
41+
aten._fused_moving_avg_obs_fq_helper,
42+
aten.gelu,
43+
aten.gelu_backward,
44+
aten.glu_backward,
45+
aten.grid_sampler_2d,
46+
aten.hardshrink,
47+
aten.hardshrink_backward,
48+
aten.hardsigmoid,
49+
aten.hardsigmoid_backward,
50+
aten.hardswish,
51+
aten.hardswish_,
52+
aten.hardswish_backward,
53+
aten.hardtanh,
54+
aten.hardtanh_,
55+
aten.hardtanh_backward,
56+
aten.heaviside,
57+
aten.huber_loss,
58+
aten.huber_loss_backward,
59+
aten.im2col,
60+
aten.index_add,
61+
aten.index_add_,
62+
aten.index_copy,
63+
aten.index_copy_,
64+
aten.index_fill,
65+
aten.index_fill_,
66+
aten.index_select,
67+
aten.isneginf,
68+
aten.isposinf,
69+
aten.l1_loss,
70+
aten.leaky_relu,
71+
aten.leaky_relu_,
72+
aten.leaky_relu_backward,
73+
aten.lerp,
74+
aten.linspace,
75+
aten.logaddexp,
76+
aten.logaddexp2,
77+
aten.logit,
78+
aten.logit_backward,
79+
aten.log_sigmoid_backward,
80+
aten.log_sigmoid_forward,
81+
aten._log_softmax,
82+
aten._log_softmax_backward_data,
83+
aten.logspace,
84+
aten.logsumexp.default,
85+
aten.masked_fill,
86+
aten.masked_fill_,
87+
aten.max_pool2d_with_indices_backward,
88+
aten.mish,
89+
aten.mse_loss,
90+
aten.mse_loss_backward,
91+
aten.mv,
92+
aten.mvlgamma,
93+
aten.nansum,
94+
aten.nan_to_num,
95+
aten.narrow,
96+
# TODO: Disable the below operators once freezing is done
97+
aten.native_batch_norm,
98+
aten.native_batch_norm_backward,
99+
aten._native_batch_norm_legit,
100+
aten._native_batch_norm_legit_functional,
101+
aten._native_batch_norm_legit_no_training,
102+
aten.native_dropout_backward,
103+
aten.native_group_norm,
104+
aten.native_group_norm_backward,
105+
aten.native_layer_norm,
106+
aten.native_layer_norm_backward,
107+
aten.new_empty,
108+
aten.new_full,
109+
aten.new_ones,
110+
aten.new_zeros,
111+
aten.nll_loss_backward,
112+
aten.nll_loss_forward,
113+
aten.norm,
114+
aten.ones,
115+
aten.ones_like,
116+
aten._prelu_kernel,
117+
aten._prelu_kernel_backward,
118+
aten._reshape_alias,
119+
aten.rad2deg,
120+
aten.renorm,
121+
aten.renorm_,
122+
aten.rot90,
123+
aten.rsub.Scalar,
124+
aten.select_backward,
125+
aten.select_scatter,
126+
aten.sgn,
127+
aten.sigmoid_backward,
128+
aten.silu,
129+
aten.silu_,
130+
aten.silu_backward,
131+
aten.sinc,
132+
aten.slice_backward,
133+
aten.smooth_l1_loss,
134+
aten.smooth_l1_loss_backward,
135+
aten.soft_margin_loss,
136+
aten.soft_margin_loss_backward,
137+
aten._softmax,
138+
aten._softmax_backward_data,
139+
aten.softplus,
140+
aten.softplus_backward,
141+
aten.softshrink,
142+
aten.softshrink_backward,
143+
aten.special_entr,
144+
aten.special_log_ndtr,
145+
aten.special_xlog1py,
146+
aten.stack,
147+
aten.t,
148+
aten.tanh_backward,
149+
aten.threshold,
150+
aten.threshold_backward,
151+
aten.trace,
152+
aten.transpose.int,
153+
aten.tril.default,
154+
aten.triu.default,
155+
aten.unfold,
156+
aten.unfold_backward,
157+
aten.unfold_copy,
158+
aten.upsample_bilinear2d,
159+
aten.upsample_bilinear2d.vec,
160+
aten.upsample_nearest2d_backward,
161+
aten.xlogy,
162+
aten.zero,
163+
aten.zero_,
164+
aten.zeros,
165+
aten.zeros_like,
166+
}
167+
disabled_decompositions = {
168+
aten.rsub.Tensor,
169+
}
6170

7-
aten = torch.ops.aten
171+
DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {
172+
enabled: _core_aten_decompositions[enabled]
173+
for enabled in enabled_decompositions
174+
if enabled in _core_aten_decompositions
175+
}
176+
EXPERIMENTAL_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {
177+
decomp: _core_aten_decompositions[decomp]
178+
for decomp in _core_aten_decompositions
179+
if decomp not in disabled_decompositions
180+
}
8181

9182

10183
def replace_inplace_op(aten_op, outplace_op):
@@ -77,5 +250,10 @@ def reciprocal_replacement(
77250
return torch.div(1, input_)
78251

79252

80-
def get_decompositions():
81-
return DECOMPOSITIONS
253+
def get_decompositions(
254+
enable_experimental_decompositions: bool = False,
255+
) -> Dict[torch._ops.OpOverload, Callable]:
256+
if enable_experimental_decompositions:
257+
return EXPERIMENTAL_DECOMPOSITIONS
258+
else:
259+
return DECOMPOSITIONS

0 commit comments

Comments
 (0)