from .converter_utils import *  # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation, convolution
- from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
- from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
- from torch_tensorrt.fx.converters.impl.elementwise import fmod
- from torch_tensorrt.fx.converters.impl.elementwise import rsub
- from torch_tensorrt.fx.converters.impl.normalization import batch_norm
- from torch_tensorrt.fx.converters.impl.normalization import layer_norm
- from torch_tensorrt.fx.converters.impl.normalization import softmax
- from torch_tensorrt.fx.converters.impl.squeeze import squeeze
- from torch_tensorrt.fx.converters.impl.select import select
- from torch_tensorrt.fx.converters.impl.slice import slice_op
- from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
- from torch_tensorrt.fx.converters.impl.condition import where
- from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
- from torch_tensorrt.fx.converters.impl.elementwise import clamp

_LOGGER: logging.Logger = logging.getLogger(__name__)

-
- def or_none(args, i):
-     return args[i] if len(args) > i else None
-
-
## converter list in alphabetic order
@tensorrt_converter(torch.ops.aten.add.Tensor)
def aten_ops_add(
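The hunk above drops the per-op `impl` imports and the `or_none` helper; the hunks below reroute the affected aten converters through `acc_ops_converters`, and `or_none`'s only caller (`aten_ops_clamp`, removed further down) goes with it. Every rerouted call follows the same acc-op calling convention visible in this diff. A minimal sketch of that convention, assuming only what the diff shows (`call_acc` is a hypothetical name, not part of the change):

    # Hypothetical helper illustrating the acc-op calling convention the
    # "+" lines below adopt: acc_ops converters take (network, target,
    # args, kwargs, name), and these aten wrappers always pass args=None.
    def call_acc(acc_fn, network, target, name, **tensor_kwargs):
        return acc_fn(network, target, None, tensor_kwargs, name)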
@@ -108,19 +89,18 @@ def aten_ops_batch_norm(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return batch_norm(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-         args[1],
-         args[2],
-         args[3],
-         args[4],
-         args[5],
-         args[6],
-         args[7],
+     kwargs_new = {
+         "input": args[0],
+         "weight": args[1],
+         "bias": args[2],
+         "running_mean": args[3],
+         "running_var": args[4],
+         "training": args[5],
+         "momentum": args[6],
+         "eps": args[7],
+     }
+     return acc_ops_converters.acc_ops_batch_norm(
+         network, target, None, kwargs_new, name
    )
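The batch_norm hunk above shows the remapping this commit repeats for each op: aten targets supply positional `args`, while acc-op converters consume named kwargs. The same mapping written table-driven, as a sketch with hypothetical helper names (the key order follows the eight-argument aten batch_norm schema above):

    # Keyword names expected by acc_ops_batch_norm, in aten positional order.
    _BATCH_NORM_KEYS = (
        "input", "weight", "bias", "running_mean",
        "running_var", "training", "momentum", "eps",
    )

    def remap_batch_norm_args(args):
        # dict(zip(...)) pairs each positional aten arg with its keyword.
        return dict(zip(_BATCH_NORM_KEYS, args))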
@@ -202,7 +182,9 @@ def aten_ops_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
-         return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
+         return acc_ops_converters.acc_ops_trunc_div(
+             network, target, None, kwargs_new, name
+         )
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
@@ -260,7 +242,11 @@ def aten_ops_fmod(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
+     kwargs_new = {
+         "input": args[0],
+         "other": args[1],
+     }
+     return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -271,40 +257,12 @@ def aten_ops_hardtanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
    return activation.hardtanh(
        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
    )


- @tensorrt_converter(torch.ops.aten.gelu.default)
- def aten_ops_gelu(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return activation.gelu(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-     )
-
-
- @tensorrt_converter(torch.ops.aten.matmul)
- @tensorrt_converter(torch.ops.aten.mm.default)
- def aten_ops_matmul(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
@tensorrt_converter(torch.ops.aten.fmod.Tensor)
def aten_ops_fmod(
    network: TRTNetwork,
@@ -328,28 +286,8 @@ def aten_ops_leaky_relu(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
- @tensorrt_converter(torch.ops.aten.layer_norm.default)
- def aten_ops_layernorm(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return layer_norm(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-         args[1],
-         args[2],
-         args[3],
-         args[4],
-     )
+     return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])


@tensorrt_converter(torch.ops.aten.linear)
@@ -452,42 +390,6 @@ def aten_ops_relu(
    )


- @tensorrt_converter(torch.ops.aten.relu.default)
- def aten_ops_relu(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-     return activation.relu(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-     )
-
-
- @tensorrt_converter(torch.ops.aten.rsqrt.default)
- def aten_ops_rsqrt(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-     return rsqrt(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-     )
-
-
@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
    network: TRTNetwork,
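The hunk above only deletes duplicate definitions: per the hunk header, an `aten_ops_relu` remains in the file, and the activation converters (relu, hardtanh, leaky_relu, tanh) keep routing through `impl.activation` with a `SourceIR.ATEN` tag rather than through `acc_ops_converters`. A sketch of that retained converter shape, mirroring the removed duplicate; treat it as illustrative, not the surviving source:

    # Illustrative only: the shape of the impl-style activation converters
    # this file keeps, tagged with SourceIR.ATEN for source tracking.
    @tensorrt_converter(torch.ops.aten.relu.default)
    def aten_ops_relu(network, target, args, kwargs, name):
        return activation.relu(network, target, SourceIR.ATEN, name, args[0])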
@@ -503,29 +405,6 @@ def aten_ops_sub(
    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)


- @tensorrt_converter(torch.ops.aten.squeeze.dim)
- @tensorrt_converter(torch.ops.aten.squeeze.dims)
- def aten_ops_squeeze(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
- @tensorrt_converter(torch.ops.aten.unsqueeze.default)
- def aten_ops_unsqueeze(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
-
-
@tensorrt_converter(torch.ops.aten.view.default)
def aten_ops_reshape(
    network: TRTNetwork,
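The removed `aten_ops_squeeze` above registered one function for two overloads by stacking `@tensorrt_converter` decorators, as the removed `aten_ops_matmul` did for `matmul` and `mm.default`. Stacking works because each decorator records the function and then returns it unchanged for the next decorator. A self-contained toy sketch of that mechanic (hypothetical registry, not this library's actual one):

    # Toy registry showing why stacked decorators register one function
    # under several targets: each decorator records fn, then returns it.
    _TOY_REGISTRY = {}

    def toy_tensorrt_converter(target):
        def register(fn):
            _TOY_REGISTRY[target] = fn
            return fn
        return register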
@@ -563,31 +442,6 @@ def aten_ops_reshape(
    return layer.get_output(0)


- @tensorrt_converter(torch.ops.aten.rsub.Tensor)
- def aten_ops_rsub(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     alpha = None
-     if "alpha" in kwargs:
-         alpha = kwargs["alpha"]
-     return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
-
-
- @tensorrt_converter(torch.ops.aten._softmax.default)
- def aten_ops_softmax(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
@tensorrt_converter(torch.ops.aten.tanh.default)
def aten_ops_tanh(
    network: TRTNetwork,
@@ -596,30 +450,12 @@ def aten_ops_tanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return activation.tanh(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-     )

-
- @tensorrt_converter(torch.ops.aten.where.self)
- def aten_ops_where(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return where(
+     return activation.tanh(
        network,
        target,
        SourceIR.ATEN,
        name,
-         args[1],
-         args[2],
        args[0],
    )
@@ -639,25 +475,6 @@ def aten_ops_cat(
    return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)


- @tensorrt_converter(torch.ops.aten.clamp.default)
- def aten_ops_clamp(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return clamp.clamp(
-         network,
-         target,
-         SourceIR.ACC,
-         name,
-         input_val=args[0],
-         min_val=or_none(args, 1),
-         max_val=or_none(args, 2),
-     )
-
-
@tensorrt_converter(torch.ops.aten.expand.default)
def aten_ops_expand(
    network: TRTNetwork,
@@ -720,17 +537,6 @@ def aten_ops_operator_add(
    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)


- @tensorrt_converter(torch.ops.aten.select.int)
- def aten_ops_select(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
-
-
@tensorrt_converter(operator.sub)
def aten_ops_operator_sub(
    network: TRTNetwork,
@@ -766,27 +572,6 @@ def aten_ops_sym_numel(
    return reduce_layer.get_output(0)


- @tensorrt_converter(torch.ops.aten.slice.Tensor)
- def aten_ops_slice(
-     network: TRTNetwork,
-     target: Target,
-     args: Tuple[Argument, ...],
-     kwargs: Dict[str, Argument],
-     name: str,
- ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-     return slice_op(
-         network,
-         target,
-         SourceIR.ATEN,
-         name,
-         args[0],
-         args[1],
-         args[2],
-         args[3],
-         args[4],
-     )
-
-
@tensorrt_converter(torch.ops.aten.sym_size)
def aten_ops_sym_size(
    network: TRTNetwork,
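Taken together, the commit reroutes batch_norm, div's "trunc" rounding mode, and fmod through kwargs-based `acc_ops_converters` calls, and deletes the impl-only registrations (gelu, matmul, layer_norm, duplicate relu, rsqrt, squeeze, unsqueeze, rsub, softmax, where, clamp, select, slice). A toy sketch of how any registered converter is ultimately looked up and invoked during lowering (hypothetical names; the real dispatch lives elsewhere in torch_tensorrt.fx):

    # Hypothetical dispatch: find the converter registered for a node's
    # target and call it with the (network, target, args, kwargs, name)
    # signature used throughout this file.
    def toy_dispatch(registry, network, node):
        converter = registry.get(node.target)
        if converter is None:
            raise RuntimeError(f"No converter found for {node.target}")
        return converter(network, node.target, node.args, node.kwargs, node.name)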