
Commit a2caefb

Revert all changes to py/torch_tensorrt/fx
Revert "fix: Add automatic type promotion for FX ops" This reverts commit f1f3716. Revert "Moved clamp to impl" This reverts commit df401dd. Revert "aten::unsqueeze impl refactor" This reverts commit b424735. Revert "Converter reorg and where operator" This reverts commit b4da15e. Revert "converter reorg and matmul" This reverts commit 7551eee. Revert "converter reorg and slice" This reverts commit 9bbdc9e. Revert "Converter reorg and select operation" This reverts commit fb70253. Revert "Converter reorg and squeeze operator" This reverts commit 294545c. Revert "Converter reorg and gelu" This reverts commit 37d1168. Revert "Converter reorg and softmax operation" This reverts commit 1ba6d13. Revert "layer_norm converter" This reverts commit e0b34b1. Revert "Converter reorg batch norm" This reverts commit 59354e5. Revert "Converter reorg and rsub" This reverts commit db15d27. Revert "Converter reorg fmod" This reverts commit ce3fa67. Revert "Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> (#1905)" This reverts commit 7158ca5. Revert "refactor: Moving elementwise and unary core to impl" This reverts commit 45e43ca.
1 parent 5420be4 commit a2caefb

38 files changed: +850 −3076 lines changed
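The substance of the revert is visible in the hunks below: the ATen converters stop calling the refactored impl helpers (batch_norm, fmod, trunc_div, and so on) and go back to repackaging their positional args into a kwargs dict and delegating to the matching acc_ops converter. A minimal, self-contained sketch of that dispatch pattern follows; the stand-in types and placeholder bodies are assumptions for illustration only, while the argument repackaging mirrors the fmod hunk in the diff below.

from typing import Any, Dict, Optional, Tuple


def acc_ops_fmod(
    network: Any,
    target: Any,
    args: Optional[Tuple[Any, ...]],
    kwargs: Dict[str, Any],
    name: str,
) -> Any:
    # Placeholder for the acc_ops converter; the real one builds TensorRT layers.
    return f"{name}: fmod({kwargs['input']}, {kwargs['other']})"


def aten_ops_fmod(
    network: Any,
    target: Any,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> Any:
    # Restored pattern: repackage positional ATen args into the kwargs layout the
    # acc_ops converter expects, then delegate to it, instead of calling the
    # removed impl.elementwise.fmod helper directly.
    kwargs_new = {"input": args[0], "other": args[1]}
    return acc_ops_fmod(network, target, None, kwargs_new, name)


if __name__ == "__main__":
    print(aten_ops_fmod(None, None, ("x", "y"), {}, "fmod_1"))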

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+489 −388
Large diffs are not rendered by default.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+23 −238
@@ -21,28 +21,9 @@
 from .converter_utils import * # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation, convolution
-from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
-from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
-from torch_tensorrt.fx.converters.impl.elementwise import fmod
-from torch_tensorrt.fx.converters.impl.elementwise import rsub
-from torch_tensorrt.fx.converters.impl.normalization import batch_norm
-from torch_tensorrt.fx.converters.impl.normalization import layer_norm
-from torch_tensorrt.fx.converters.impl.normalization import softmax
-from torch_tensorrt.fx.converters.impl.squeeze import squeeze
-from torch_tensorrt.fx.converters.impl.select import select
-from torch_tensorrt.fx.converters.impl.slice import slice_op
-from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
-from torch_tensorrt.fx.converters.impl.condition import where
-from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
-from torch_tensorrt.fx.converters.impl.elementwise import clamp
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-
-def or_none(args, i):
-    return args[i] if len(args) > i else None
-
-
 ## converter list in alphabetic order
 @tensorrt_converter(torch.ops.aten.add.Tensor)
 def aten_ops_add(
@@ -108,19 +89,18 @@ def aten_ops_batch_norm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return batch_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-        args[5],
-        args[6],
-        args[7],
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "running_mean": args[3],
+        "running_var": args[4],
+        "training": args[5],
+        "momentum": args[6],
+        "eps": args[7],
+    }
+    return acc_ops_converters.acc_ops_batch_norm(
+        network, target, None, kwargs_new, name
     )
 
 
@@ -202,7 +182,9 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
+        return acc_ops_converters.acc_ops_trunc_div(
+            network, target, None, kwargs_new, name
+        )
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"
@@ -260,7 +242,11 @@ def aten_ops_fmod(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -271,40 +257,12 @@ def aten_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
     return activation.hardtanh(
         network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )
 
 
-@tensorrt_converter(torch.ops.aten.gelu.default)
-def aten_ops_gelu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.gelu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.matmul)
-@tensorrt_converter(torch.ops.aten.mm.default)
-def aten_ops_matmul(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
 @tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
     network: TRTNetwork,
@@ -328,28 +286,8 @@ def aten_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
-
 
-@tensorrt_converter(torch.ops.aten.layer_norm.default)
-def aten_ops_layernorm(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return layer_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 @tensorrt_converter(torch.ops.aten.linear)
@@ -452,42 +390,6 @@ def aten_ops_relu(
     )
 
 
-@tensorrt_converter(torch.ops.aten.relu.default)
-def aten_ops_relu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return activation.relu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.rsqrt.default)
-def aten_ops_rsqrt(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return rsqrt(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
 def aten_ops_sub(
     network: TRTNetwork,
@@ -503,29 +405,6 @@ def aten_ops_sub(
     return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.squeeze.dim)
-@tensorrt_converter(torch.ops.aten.squeeze.dims)
-def aten_ops_squeeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
-@tensorrt_converter(torch.ops.aten.unsqueeze.default)
-def aten_ops_unsqueeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
-
-
 @tensorrt_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
     network: TRTNetwork,
@@ -563,31 +442,6 @@ def aten_ops_reshape(
     return layer.get_output(0)
 
 
-@tensorrt_converter(torch.ops.aten.rsub.Tensor)
-def aten_ops_rsub(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    alpha = None
-    if "alpha" in kwargs:
-        alpha = kwargs["alpha"]
-    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
-
-
-@tensorrt_converter(torch.ops.aten._softmax.default)
-def aten_ops_softmax(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
@@ -596,30 +450,12 @@ def aten_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.tanh(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
 
-
-@tensorrt_converter(torch.ops.aten.where.self)
-def aten_ops_where(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return where(
+    return activation.tanh(
         network,
         target,
         SourceIR.ATEN,
         name,
-        args[1],
-        args[2],
         args[0],
     )
 
@@ -639,25 +475,6 @@ def aten_ops_cat(
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.clamp.default)
-def aten_ops_clamp(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return clamp.clamp(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        input_val=args[0],
-        min_val=or_none(args, 1),
-        max_val=or_none(args, 2),
-    )
-
-
 @tensorrt_converter(torch.ops.aten.expand.default)
 def aten_ops_expand(
     network: TRTNetwork,
@@ -720,17 +537,6 @@ def aten_ops_operator_add(
     return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.select.int)
-def aten_ops_select(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
-
-
 @tensorrt_converter(operator.sub)
 def aten_ops_operator_sub(
     network: TRTNetwork,
@@ -766,27 +572,6 @@ def aten_ops_sym_numel(
     return reduce_layer.get_output(0)
 
 
-@tensorrt_converter(torch.ops.aten.slice.Tensor)
-def aten_ops_slice(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return slice_op(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
-
-
 @tensorrt_converter(torch.ops.aten.sym_size)
 def aten_ops_sym_size(
     network: TRTNetwork,
