|
| 1 | +from typing import Any, Callable, Dict, Set |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch._decomp import core_aten_decompositions |
| 5 | +from torch._decomp import get_decompositions as get_torch_decompositions |
| 6 | +from torch._ops import OpOverload |
| 7 | + |
| 8 | +aten = torch.ops.aten |
| 9 | + |
| 10 | +_core_aten_decompositions: Dict[ |
| 11 | + OpOverload, Callable[[Any], Any] |
| 12 | +] = core_aten_decompositions() |
| 13 | +torch_enabled_decompositions: Set[OpOverload] = { |
| 14 | + aten._adaptive_avg_pool2d_backward, |
| 15 | + aten.addcdiv, |
| 16 | + aten.addcdiv_, |
| 17 | + aten.addcmul, |
| 18 | + aten.addcmul_, |
| 19 | + aten.addr, |
| 20 | + aten.aminmax, |
| 21 | + aten.arange.default, |
| 22 | + aten.arange.start, |
| 23 | + aten.avg_pool2d_backward, |
| 24 | + aten.binary_cross_entropy, |
| 25 | + aten.binary_cross_entropy_backward, |
| 26 | + aten.binary_cross_entropy_with_logits, |
| 27 | + aten.celu, |
| 28 | + aten.col2im, |
| 29 | + aten.count_nonzero, |
| 30 | + aten.cudnn_batch_norm, |
| 31 | + aten.cudnn_batch_norm_backward, |
| 32 | + aten.deg2rad, |
| 33 | + aten.detach, |
| 34 | + aten.diag_embed, |
| 35 | + aten.diagonal_backward, |
| 36 | + aten.dot, |
| 37 | + aten.elu, |
| 38 | + aten.elu_backward, |
| 39 | + aten._embedding_bag, |
| 40 | + aten.embedding_dense_backward, |
| 41 | + aten._euclidean_dist.default, |
| 42 | + aten.expand_as, |
| 43 | + aten.eye, |
| 44 | + aten.fill, |
| 45 | + aten.frac, |
| 46 | + aten._fused_moving_avg_obs_fq_helper, |
| 47 | + aten.gelu, |
| 48 | + aten.gelu_backward, |
| 49 | + aten.glu_backward, |
| 50 | + aten.grid_sampler_2d, |
| 51 | + aten.hardshrink, |
| 52 | + aten.hardshrink_backward, |
| 53 | + aten.hardsigmoid, |
| 54 | + aten.hardsigmoid_backward, |
| 55 | + aten.hardswish, |
| 56 | + aten.hardswish_, |
| 57 | + aten.hardswish_backward, |
| 58 | + aten.hardtanh, |
| 59 | + aten.hardtanh_, |
| 60 | + aten.hardtanh_backward, |
| 61 | + aten.heaviside, |
| 62 | + aten.huber_loss, |
| 63 | + aten.huber_loss_backward, |
| 64 | + aten.im2col, |
| 65 | + aten.index_add, |
| 66 | + aten.index_add_, |
| 67 | + aten.index_copy, |
| 68 | + aten.index_copy_, |
| 69 | + aten.index_fill, |
| 70 | + aten.index_fill_, |
| 71 | + aten.index_select, |
| 72 | + aten.isneginf, |
| 73 | + aten.isposinf, |
| 74 | + aten.l1_loss, |
| 75 | + aten.leaky_relu, |
| 76 | + aten.leaky_relu_, |
| 77 | + aten.leaky_relu_backward, |
| 78 | + aten.lerp, |
| 79 | + aten.linspace, |
| 80 | + aten.logaddexp, |
| 81 | + aten.logaddexp2, |
| 82 | + aten.logit, |
| 83 | + aten.logit_backward, |
| 84 | + aten.log_sigmoid_backward, |
| 85 | + aten.log_sigmoid_forward, |
| 86 | + aten._log_softmax, |
| 87 | + aten._log_softmax_backward_data, |
| 88 | + aten.logspace, |
| 89 | + aten.logsumexp.default, |
| 90 | + aten.masked_fill, |
| 91 | + aten.masked_fill_, |
| 92 | + aten.max_pool2d_with_indices_backward, |
| 93 | + aten.mish, |
| 94 | + aten.mse_loss, |
| 95 | + aten.mse_loss_backward, |
| 96 | + aten.mv, |
| 97 | + aten.mvlgamma, |
| 98 | + aten.nansum, |
| 99 | + aten.nan_to_num, |
| 100 | + aten.narrow, |
| 101 | + # TODO: Disable the below operators once freezing is done |
| 102 | + aten.native_batch_norm, |
| 103 | + aten.native_batch_norm_backward, |
| 104 | + aten._native_batch_norm_legit, |
| 105 | + aten._native_batch_norm_legit_functional, |
| 106 | + aten._native_batch_norm_legit_no_training, |
| 107 | + aten.native_dropout_backward, |
| 108 | + aten.native_group_norm, |
| 109 | + aten.native_group_norm_backward, |
| 110 | + aten.native_layer_norm, |
| 111 | + aten.native_layer_norm_backward, |
| 112 | + aten.new_empty, |
| 113 | + aten.new_full, |
| 114 | + aten.new_ones, |
| 115 | + aten.new_zeros, |
| 116 | + aten.nll_loss_backward, |
| 117 | + aten.nll_loss_forward, |
| 118 | + aten.norm, |
| 119 | + aten.ones, |
| 120 | + aten.ones_like, |
| 121 | + aten._prelu_kernel, |
| 122 | + aten._prelu_kernel_backward, |
| 123 | + aten._reshape_alias, |
| 124 | + aten.rad2deg, |
| 125 | + aten.renorm, |
| 126 | + aten.renorm_, |
| 127 | + aten.rot90, |
| 128 | + aten.rsub.Scalar, |
| 129 | + aten.rsub.Tensor, |
| 130 | + aten.select_backward, |
| 131 | + aten.select_scatter, |
| 132 | + aten.sgn, |
| 133 | + aten.sigmoid_backward, |
| 134 | + aten.silu, |
| 135 | + aten.silu_, |
| 136 | + aten.silu_backward, |
| 137 | + aten.sinc, |
| 138 | + aten.slice_backward, |
| 139 | + aten.smooth_l1_loss, |
| 140 | + aten.smooth_l1_loss_backward, |
| 141 | + aten.soft_margin_loss, |
| 142 | + aten.soft_margin_loss_backward, |
| 143 | + aten._softmax, |
| 144 | + aten._softmax_backward_data, |
| 145 | + aten.softplus, |
| 146 | + aten.softplus_backward, |
| 147 | + aten.softshrink, |
| 148 | + aten.softshrink_backward, |
| 149 | + aten.special_entr, |
| 150 | + aten.special_log_ndtr, |
| 151 | + aten.special_xlog1py, |
| 152 | + aten.stack, |
| 153 | + aten.t, |
| 154 | + aten.tanh_backward, |
| 155 | + aten.threshold, |
| 156 | + aten.threshold_backward, |
| 157 | + aten.trace, |
| 158 | + aten.transpose.int, |
| 159 | + aten.tril.default, |
| 160 | + aten.triu.default, |
| 161 | + aten.unfold, |
| 162 | + aten.unfold_backward, |
| 163 | + aten.unfold_copy, |
| 164 | + aten.upsample_bilinear2d, |
| 165 | + aten.upsample_bilinear2d.vec, |
| 166 | + aten.upsample_nearest2d_backward, |
| 167 | + aten.xlogy, |
| 168 | + aten.zero, |
| 169 | + aten.zero_, |
| 170 | + aten.zeros, |
| 171 | + aten.zeros_like, |
| 172 | + # Non-default convenience decompositions |
| 173 | + aten.clamp_min, |
| 174 | + aten.clamp_max, |
| 175 | + aten.linalg_vector_norm, |
| 176 | + aten.full, |
| 177 | + aten.repeat, |
| 178 | +} |
| 179 | +torch_disabled_decompositions: Set[OpOverload] = set() |
| 180 | + |
| 181 | + |
| 182 | +ENABLED_TORCH_DECOMPOSITIONS: Dict[ |
| 183 | + OpOverload, Callable[[Any], Any] |
| 184 | +] = get_torch_decompositions(torch_enabled_decompositions) |
| 185 | +TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {} |
| 186 | + |
| 187 | + |
| 188 | +def check_decomp_set_invariants() -> None: |
| 189 | + """Validates no overlap between enabled and disabled decomposition sets""" |
| 190 | + overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions) |
| 191 | + |
| 192 | + if overlap: |
| 193 | + raise AssertionError( |
| 194 | + f"Detected {overlap} registered in both torch_enabled_decompositions " |
| 195 | + "and torch_disabled_decompositions. Ensure all operator(s) are in " |
| 196 | + "at most one of the two sets." |
| 197 | + ) |
| 198 | + |
| 199 | + |
| 200 | +check_decomp_set_invariants() |
0 commit comments