Skip to content

Commit 5526bca

Browse files
committed
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into export_prototype
2 parents 157bb2d + 0005a31 commit 5526bca

File tree

1 file changed

+40
-45
lines changed

1 file changed

+40
-45
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def aten_ops_fmod(
9494
return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
9595

9696

97-
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
97+
@dynamo_tensorrt_converter(torch.ops.aten.relu.default) # type: ignore[misc]
9898
def aten_ops_relu(
9999
network: TRTNetwork,
100100
target: Target,
@@ -111,7 +111,7 @@ def aten_ops_relu(
111111
)
112112

113113

114-
@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
114+
@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default) # type: ignore[misc]
115115
def aten_ops_sigmoid(
116116
network: TRTNetwork,
117117
target: Target,
@@ -128,7 +128,7 @@ def aten_ops_sigmoid(
128128
)
129129

130130

131-
@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
131+
@dynamo_tensorrt_converter(torch.ops.aten.tanh.default) # type: ignore[misc]
132132
def aten_ops_tanh(
133133
network: TRTNetwork,
134134
target: Target,
@@ -145,7 +145,7 @@ def aten_ops_tanh(
145145
)
146146

147147

148-
@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
148+
@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default) # type: ignore[misc]
149149
def aten_ops_leaky_relu(
150150
network: TRTNetwork,
151151
target: Target,
@@ -163,7 +163,7 @@ def aten_ops_leaky_relu(
163163
)
164164

165165

166-
@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
166+
@dynamo_tensorrt_converter(torch.ops.aten.elu.default) # type: ignore[misc]
167167
def aten_ops_elu(
168168
network: TRTNetwork,
169169
target: Target,
@@ -182,7 +182,7 @@ def aten_ops_elu(
182182
)
183183

184184

185-
@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
185+
@dynamo_tensorrt_converter(torch.ops.aten.softplus.default) # type: ignore[misc]
186186
def aten_ops_softplus(
187187
network: TRTNetwork,
188188
target: Target,
@@ -200,7 +200,7 @@ def aten_ops_softplus(
200200
)
201201

202202

203-
@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
203+
@dynamo_tensorrt_converter(torch.ops.aten.clip.default) # type: ignore[misc]
204204
def aten_ops_clip(
205205
network: TRTNetwork,
206206
target: Target,
@@ -219,7 +219,7 @@ def aten_ops_clip(
219219
)
220220

221221

222-
@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
222+
@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default) # type: ignore[misc]
223223
def aten_ops_hard_sigmoid(
224224
network: TRTNetwork,
225225
target: Target,
@@ -296,26 +296,20 @@ def aten_ops_rsqrt(
296296
)
297297

298298

299-
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
299+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default) # type: ignore[misc]
300300
def aten_ops_neg(
301301
network: TRTNetwork,
302302
target: Target,
303303
args: Tuple[Argument, ...],
304304
kwargs: Dict[str, Argument],
305305
name: str,
306306
) -> Union[TRTTensor, Sequence[TRTTensor]]:
307-
input_val = args[0]
308-
if (isinstance(input_val, TRTTensor)) and (
309-
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
310-
):
311-
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
312-
313307
return impl.unary.neg(
314308
network,
315309
target,
316310
SourceIR.ATEN,
317311
name,
318-
input_val,
312+
args[0],
319313
)
320314

321315

@@ -552,8 +546,8 @@ def aten_ops_amax(
552546
)
553547

554548

555-
@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
556-
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
549+
@dynamo_tensorrt_converter(torch.ops.aten.sum.default) # type: ignore[misc]
550+
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) # type: ignore[misc]
557551
def aten_ops_sum(
558552
network: TRTNetwork,
559553
target: Target,
@@ -946,8 +940,8 @@ def aten_ops_isinf(
946940
)
947941

948942

949-
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
950-
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
943+
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) # type: ignore[misc]
944+
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) # type: ignore[misc]
951945
def aten_ops_add(
952946
network: TRTNetwork,
953947
target: Target,
@@ -978,8 +972,8 @@ def aten_ops_add(
978972
)
979973

980974

981-
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
982-
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
975+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) # type: ignore[misc]
976+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) # type: ignore[misc]
983977
def aten_ops_mul(
984978
network: TRTNetwork,
985979
target: Target,
@@ -997,7 +991,7 @@ def aten_ops_mul(
997991
)
998992

999993

1000-
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
994+
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) # type: ignore[misc]
1001995
def aten_ops_max(
1002996
network: TRTNetwork,
1003997
target: Target,
@@ -1015,7 +1009,7 @@ def aten_ops_max(
10151009
)
10161010

10171011

1018-
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
1012+
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) # type: ignore[misc]
10191013
def aten_ops_min(
10201014
network: TRTNetwork,
10211015
target: Target,
@@ -1033,8 +1027,8 @@ def aten_ops_min(
10331027
)
10341028

10351029

1036-
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
1037-
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
1030+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) # type: ignore[misc]
1031+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) # type: ignore[misc]
10381032
def aten_ops_sub(
10391033
network: TRTNetwork,
10401034
target: Target,
@@ -1065,10 +1059,10 @@ def aten_ops_sub(
10651059
)
10661060

10671061

1068-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
1069-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
1070-
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
1071-
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
1062+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc]
1063+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
1064+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) # type: ignore[misc]
1065+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) # type: ignore[misc]
10721066
def aten_ops_div(
10731067
network: TRTNetwork,
10741068
target: Target,
@@ -1111,9 +1105,9 @@ def aten_ops_div(
11111105
)
11121106

11131107

1114-
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
1115-
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
1116-
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
1108+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) # type: ignore[misc]
1109+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) # type: ignore[misc]
1110+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) # type: ignore[misc]
11171111
def aten_ops_pow(
11181112
network: TRTNetwork,
11191113
target: Target,
@@ -1131,8 +1125,8 @@ def aten_ops_pow(
11311125
)
11321126

11331127

1134-
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
1135-
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
1128+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default) # type: ignore[misc]
1129+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar) # type: ignore[misc]
11361130
def aten_ops_floor_div(
11371131
network: TRTNetwork,
11381132
target: Target,
@@ -1150,7 +1144,7 @@ def aten_ops_floor_div(
11501144
)
11511145

11521146

1153-
@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
1147+
@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default) # type: ignore[misc]
11541148
def aten_ops_logical_and(
11551149
network: TRTNetwork,
11561150
target: Target,
@@ -1168,7 +1162,7 @@ def aten_ops_logical_and(
11681162
)
11691163

11701164

1171-
@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
1165+
@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) # type: ignore[misc]
11721166
def aten_ops_logical_or(
11731167
network: TRTNetwork,
11741168
target: Target,
@@ -1186,7 +1180,7 @@ def aten_ops_logical_or(
11861180
)
11871181

11881182

1189-
@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
1183+
@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) # type: ignore[misc]
11901184
def aten_ops_logical_xor(
11911185
network: TRTNetwork,
11921186
target: Target,
@@ -1204,8 +1198,8 @@ def aten_ops_logical_xor(
12041198
)
12051199

12061200

1207-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1208-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
1201+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc]
1202+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc]
12091203
def aten_ops_equal(
12101204
network: TRTNetwork,
12111205
target: Target,
@@ -1223,8 +1217,8 @@ def aten_ops_equal(
12231217
)
12241218

12251219

1226-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1227-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
1220+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc]
1221+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc]
12281222
def aten_ops_greater(
12291223
network: TRTNetwork,
12301224
target: Target,
@@ -1242,8 +1236,8 @@ def aten_ops_greater(
12421236
)
12431237

12441238

1245-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1246-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
1239+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc]
1240+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc]
12471241
def aten_ops_less(
12481242
network: TRTNetwork,
12491243
target: Target,
@@ -1292,7 +1286,8 @@ def aten_ops_convolution(
12921286
)
12931287

12941288

1295-
@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
1289+
@dynamo_tensorrt_converter(torch.ops.aten.linear.default) # type: ignore[misc]
1290+
@dynamo_tensorrt_converter(torch.ops.aten.linear) # type: ignore[misc]
12961291
def aten_ops_linear(
12971292
network: TRTNetwork,
12981293
target: Target,

0 commit comments

Comments
 (0)