|
| 1 | +from typing import Callable, Dict |
1 | 2 | import torch
|
2 | 3 | from torch._decomp import register_decomposition, core_aten_decompositions
|
3 | 4 |
|
| 5 | +aten = torch.ops.aten |
4 | 6 |
|
5 |
| -DECOMPOSITIONS = {**core_aten_decompositions()} |
| 7 | +_core_aten_decompositions = core_aten_decompositions() |
| 8 | +enabled_decompositions = { |
| 9 | + aten._adaptive_avg_pool2d_backward, |
| 10 | + aten.addcdiv, |
| 11 | + aten.addcdiv_, |
| 12 | + aten.addcmul, |
| 13 | + aten.addcmul_, |
| 14 | + aten.addr, |
| 15 | + aten.aminmax, |
| 16 | + aten.arange.default, |
| 17 | + aten.arange.start, |
| 18 | + aten.avg_pool2d_backward, |
| 19 | + aten.binary_cross_entropy, |
| 20 | + aten.binary_cross_entropy_backward, |
| 21 | + aten.binary_cross_entropy_with_logits, |
| 22 | + aten.celu, |
| 23 | + aten.col2im, |
| 24 | + aten.count_nonzero, |
| 25 | + aten.cudnn_batch_norm, |
| 26 | + aten.cudnn_batch_norm_backward, |
| 27 | + aten.deg2rad, |
| 28 | + aten.detach, |
| 29 | + aten.diag_embed, |
| 30 | + aten.diagonal_backward, |
| 31 | + aten.dot, |
| 32 | + aten.elu, |
| 33 | + aten.elu_backward, |
| 34 | + aten._embedding_bag, |
| 35 | + aten.embedding_dense_backward, |
| 36 | + aten._euclidean_dist.default, |
| 37 | + aten.expand_as, |
| 38 | + aten.eye, |
| 39 | + aten.fill, |
| 40 | + aten.frac, |
| 41 | + aten._fused_moving_avg_obs_fq_helper, |
| 42 | + aten.gelu, |
| 43 | + aten.gelu_backward, |
| 44 | + aten.glu_backward, |
| 45 | + aten.grid_sampler_2d, |
| 46 | + aten.hardshrink, |
| 47 | + aten.hardshrink_backward, |
| 48 | + aten.hardsigmoid, |
| 49 | + aten.hardsigmoid_backward, |
| 50 | + aten.hardswish, |
| 51 | + aten.hardswish_, |
| 52 | + aten.hardswish_backward, |
| 53 | + aten.hardtanh, |
| 54 | + aten.hardtanh_, |
| 55 | + aten.hardtanh_backward, |
| 56 | + aten.heaviside, |
| 57 | + aten.huber_loss, |
| 58 | + aten.huber_loss_backward, |
| 59 | + aten.im2col, |
| 60 | + aten.index_add, |
| 61 | + aten.index_add_, |
| 62 | + aten.index_copy, |
| 63 | + aten.index_copy_, |
| 64 | + aten.index_fill, |
| 65 | + aten.index_fill_, |
| 66 | + aten.index_select, |
| 67 | + aten.isneginf, |
| 68 | + aten.isposinf, |
| 69 | + aten.l1_loss, |
| 70 | + aten.leaky_relu, |
| 71 | + aten.leaky_relu_, |
| 72 | + aten.leaky_relu_backward, |
| 73 | + aten.lerp, |
| 74 | + aten.linspace, |
| 75 | + aten.logaddexp, |
| 76 | + aten.logaddexp2, |
| 77 | + aten.logit, |
| 78 | + aten.logit_backward, |
| 79 | + aten.log_sigmoid_backward, |
| 80 | + aten.log_sigmoid_forward, |
| 81 | + aten._log_softmax, |
| 82 | + aten._log_softmax_backward_data, |
| 83 | + aten.logspace, |
| 84 | + aten.logsumexp.default, |
| 85 | + aten.masked_fill, |
| 86 | + aten.masked_fill_, |
| 87 | + aten.max_pool2d_with_indices_backward, |
| 88 | + aten.mish, |
| 89 | + aten.mse_loss, |
| 90 | + aten.mse_loss_backward, |
| 91 | + aten.mv, |
| 92 | + aten.mvlgamma, |
| 93 | + aten.nansum, |
| 94 | + aten.nan_to_num, |
| 95 | + aten.narrow, |
| 96 | + # TODO: Disable the below operators once freezing is done |
| 97 | + aten.native_batch_norm, |
| 98 | + aten.native_batch_norm_backward, |
| 99 | + aten._native_batch_norm_legit, |
| 100 | + aten._native_batch_norm_legit_functional, |
| 101 | + aten._native_batch_norm_legit_no_training, |
| 102 | + aten.native_dropout_backward, |
| 103 | + aten.native_group_norm, |
| 104 | + aten.native_group_norm_backward, |
| 105 | + aten.native_layer_norm, |
| 106 | + aten.native_layer_norm_backward, |
| 107 | + aten.new_empty, |
| 108 | + aten.new_full, |
| 109 | + aten.new_ones, |
| 110 | + aten.new_zeros, |
| 111 | + aten.nll_loss_backward, |
| 112 | + aten.nll_loss_forward, |
| 113 | + aten.norm, |
| 114 | + aten.ones, |
| 115 | + aten.ones_like, |
| 116 | + aten._prelu_kernel, |
| 117 | + aten._prelu_kernel_backward, |
| 118 | + aten._reshape_alias, |
| 119 | + aten.rad2deg, |
| 120 | + aten.renorm, |
| 121 | + aten.renorm_, |
| 122 | + aten.rot90, |
| 123 | + aten.rsub.Scalar, |
| 124 | + aten.select_backward, |
| 125 | + aten.select_scatter, |
| 126 | + aten.sgn, |
| 127 | + aten.sigmoid_backward, |
| 128 | + aten.silu, |
| 129 | + aten.silu_, |
| 130 | + aten.silu_backward, |
| 131 | + aten.sinc, |
| 132 | + aten.slice_backward, |
| 133 | + aten.smooth_l1_loss, |
| 134 | + aten.smooth_l1_loss_backward, |
| 135 | + aten.soft_margin_loss, |
| 136 | + aten.soft_margin_loss_backward, |
| 137 | + aten._softmax, |
| 138 | + aten._softmax_backward_data, |
| 139 | + aten.softplus, |
| 140 | + aten.softplus_backward, |
| 141 | + aten.softshrink, |
| 142 | + aten.softshrink_backward, |
| 143 | + aten.special_entr, |
| 144 | + aten.special_log_ndtr, |
| 145 | + aten.special_xlog1py, |
| 146 | + aten.stack, |
| 147 | + aten.t, |
| 148 | + aten.tanh_backward, |
| 149 | + aten.threshold, |
| 150 | + aten.threshold_backward, |
| 151 | + aten.trace, |
| 152 | + aten.transpose.int, |
| 153 | + aten.tril.default, |
| 154 | + aten.triu.default, |
| 155 | + aten.unfold, |
| 156 | + aten.unfold_backward, |
| 157 | + aten.unfold_copy, |
| 158 | + aten.upsample_bilinear2d, |
| 159 | + aten.upsample_bilinear2d.vec, |
| 160 | + aten.upsample_nearest2d_backward, |
| 161 | + aten.xlogy, |
| 162 | + aten.zero, |
| 163 | + aten.zero_, |
| 164 | + aten.zeros, |
| 165 | + aten.zeros_like, |
| 166 | +} |
| 167 | +disabled_decompositions = { |
| 168 | + aten.rsub.Tensor, |
| 169 | +} |
6 | 170 |
|
7 |
| -aten = torch.ops.aten |
| 171 | +DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = { |
| 172 | + enabled: _core_aten_decompositions[enabled] |
| 173 | + for enabled in enabled_decompositions |
| 174 | + if enabled in _core_aten_decompositions |
| 175 | +} |
| 176 | +EXPERIMENTAL_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = { |
| 177 | + decomp: _core_aten_decompositions[decomp] |
| 178 | + for decomp in _core_aten_decompositions |
| 179 | + if decomp not in disabled_decompositions |
| 180 | +} |
8 | 181 |
|
9 | 182 |
|
10 | 183 | def replace_inplace_op(aten_op, outplace_op):
|
@@ -77,5 +250,10 @@ def reciprocal_replacement(
|
77 | 250 | return torch.div(1, input_)
|
78 | 251 |
|
79 | 252 |
|
80 |
| -def get_decompositions(): |
81 |
| - return DECOMPOSITIONS |
| 253 | +def get_decompositions( |
| 254 | + enable_experimental_decompositions: bool = False, |
| 255 | +) -> Dict[torch._ops.OpOverload, Callable]: |
| 256 | + if enable_experimental_decompositions: |
| 257 | + return EXPERIMENTAL_DECOMPOSITIONS |
| 258 | + else: |
| 259 | + return DECOMPOSITIONS |
0 commit comments