Skip to content

Commit d49cadb

Browse files
committed
feat: Move decomposition default groups to new file
1 parent 5883ef6 commit d49cadb

File tree

2 files changed

+209
-197
lines changed

2 files changed

+209
-197
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from typing import Any, Callable, Dict, Set
2+
3+
import torch
4+
from torch._decomp import core_aten_decompositions
5+
from torch._decomp import get_decompositions as get_torch_decompositions
6+
from torch._ops import OpOverload
7+
8+
aten = torch.ops.aten
9+
10+
_core_aten_decompositions: Dict[
11+
OpOverload, Callable[[Any], Any]
12+
] = core_aten_decompositions()
13+
torch_enabled_decompositions: Set[OpOverload] = {
14+
aten._adaptive_avg_pool2d_backward,
15+
aten.addcdiv,
16+
aten.addcdiv_,
17+
aten.addcmul,
18+
aten.addcmul_,
19+
aten.addr,
20+
aten.aminmax,
21+
aten.arange.default,
22+
aten.arange.start,
23+
aten.avg_pool2d_backward,
24+
aten.binary_cross_entropy,
25+
aten.binary_cross_entropy_backward,
26+
aten.binary_cross_entropy_with_logits,
27+
aten.celu,
28+
aten.col2im,
29+
aten.count_nonzero,
30+
aten.cudnn_batch_norm,
31+
aten.cudnn_batch_norm_backward,
32+
aten.deg2rad,
33+
aten.detach,
34+
aten.diag_embed,
35+
aten.diagonal_backward,
36+
aten.dot,
37+
aten.elu,
38+
aten.elu_backward,
39+
aten._embedding_bag,
40+
aten.embedding_dense_backward,
41+
aten._euclidean_dist.default,
42+
aten.expand_as,
43+
aten.eye,
44+
aten.fill,
45+
aten.frac,
46+
aten._fused_moving_avg_obs_fq_helper,
47+
aten.gelu,
48+
aten.gelu_backward,
49+
aten.glu_backward,
50+
aten.grid_sampler_2d,
51+
aten.hardshrink,
52+
aten.hardshrink_backward,
53+
aten.hardsigmoid,
54+
aten.hardsigmoid_backward,
55+
aten.hardswish,
56+
aten.hardswish_,
57+
aten.hardswish_backward,
58+
aten.hardtanh,
59+
aten.hardtanh_,
60+
aten.hardtanh_backward,
61+
aten.heaviside,
62+
aten.huber_loss,
63+
aten.huber_loss_backward,
64+
aten.im2col,
65+
aten.index_add,
66+
aten.index_add_,
67+
aten.index_copy,
68+
aten.index_copy_,
69+
aten.index_fill,
70+
aten.index_fill_,
71+
aten.index_select,
72+
aten.isneginf,
73+
aten.isposinf,
74+
aten.l1_loss,
75+
aten.leaky_relu,
76+
aten.leaky_relu_,
77+
aten.leaky_relu_backward,
78+
aten.lerp,
79+
aten.linspace,
80+
aten.logaddexp,
81+
aten.logaddexp2,
82+
aten.logit,
83+
aten.logit_backward,
84+
aten.log_sigmoid_backward,
85+
aten.log_sigmoid_forward,
86+
aten._log_softmax,
87+
aten._log_softmax_backward_data,
88+
aten.logspace,
89+
aten.logsumexp.default,
90+
aten.masked_fill,
91+
aten.masked_fill_,
92+
aten.max_pool2d_with_indices_backward,
93+
aten.mish,
94+
aten.mse_loss,
95+
aten.mse_loss_backward,
96+
aten.mv,
97+
aten.mvlgamma,
98+
aten.nansum,
99+
aten.nan_to_num,
100+
aten.narrow,
101+
# TODO: Disable the below operators once freezing is done
102+
aten.native_batch_norm,
103+
aten.native_batch_norm_backward,
104+
aten._native_batch_norm_legit,
105+
aten._native_batch_norm_legit_functional,
106+
aten._native_batch_norm_legit_no_training,
107+
aten.native_dropout_backward,
108+
aten.native_group_norm,
109+
aten.native_group_norm_backward,
110+
aten.native_layer_norm,
111+
aten.native_layer_norm_backward,
112+
aten.new_empty,
113+
aten.new_full,
114+
aten.new_ones,
115+
aten.new_zeros,
116+
aten.nll_loss_backward,
117+
aten.nll_loss_forward,
118+
aten.norm,
119+
aten.ones,
120+
aten.ones_like,
121+
aten._prelu_kernel,
122+
aten._prelu_kernel_backward,
123+
aten._reshape_alias,
124+
aten.rad2deg,
125+
aten.renorm,
126+
aten.renorm_,
127+
aten.rot90,
128+
aten.rsub.Scalar,
129+
aten.rsub.Tensor,
130+
aten.select_backward,
131+
aten.select_scatter,
132+
aten.sgn,
133+
aten.sigmoid_backward,
134+
aten.silu,
135+
aten.silu_,
136+
aten.silu_backward,
137+
aten.sinc,
138+
aten.slice_backward,
139+
aten.smooth_l1_loss,
140+
aten.smooth_l1_loss_backward,
141+
aten.soft_margin_loss,
142+
aten.soft_margin_loss_backward,
143+
aten._softmax,
144+
aten._softmax_backward_data,
145+
aten.softplus,
146+
aten.softplus_backward,
147+
aten.softshrink,
148+
aten.softshrink_backward,
149+
aten.special_entr,
150+
aten.special_log_ndtr,
151+
aten.special_xlog1py,
152+
aten.stack,
153+
aten.t,
154+
aten.tanh_backward,
155+
aten.threshold,
156+
aten.threshold_backward,
157+
aten.trace,
158+
aten.transpose.int,
159+
aten.tril.default,
160+
aten.triu.default,
161+
aten.unfold,
162+
aten.unfold_backward,
163+
aten.unfold_copy,
164+
aten.upsample_bilinear2d,
165+
aten.upsample_bilinear2d.vec,
166+
aten.upsample_nearest2d_backward,
167+
aten.xlogy,
168+
aten.zero,
169+
aten.zero_,
170+
aten.zeros,
171+
aten.zeros_like,
172+
# Non-default convenience decompositions
173+
aten.clamp_min,
174+
aten.clamp_max,
175+
aten.linalg_vector_norm,
176+
aten.full,
177+
aten.repeat,
178+
}
179+
torch_disabled_decompositions: Set[OpOverload] = set()
180+
181+
182+
ENABLED_TORCH_DECOMPOSITIONS: Dict[
183+
OpOverload, Callable[[Any], Any]
184+
] = get_torch_decompositions(torch_enabled_decompositions)
185+
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
186+
187+
188+
def check_decomp_set_invariants() -> None:
189+
"""Validates no overlap between enabled and disabled decomposition sets"""
190+
overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
191+
192+
if overlap:
193+
raise AssertionError(
194+
f"Detected {overlap} registered in both torch_enabled_decompositions "
195+
"and torch_disabled_decompositions. Ensure all operator(s) are in "
196+
"at most one of the two sets."
197+
)
198+
199+
200+
check_decomp_set_invariants()

0 commit comments

Comments
 (0)