 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

-import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_int_int_div_trt_tensor,
-    cast_trt_tensor,
-)
-from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

 from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.div.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
-def aten_ops_div(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    # If both are TRTTensor, both are cast to float32
-    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
-        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
-            network,
-            kwargs_new["input"],
-            kwargs_new["other"],
-            name,
-        )
-    # If one is TRTTensor, it is cast to float32
-    elif isinstance(args[0], TRTTensor) and (
-        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
-    ):
-        kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name, target
-        )
-    elif isinstance(args[1], TRTTensor) and (
-        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
-    ):
-        kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name, target
-        )
-    rounding_mode = kwargs.get("rounding_mode")
-    if rounding_mode is None:
-        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
-    elif rounding_mode == "floor":
-        return acc_ops_converters.acc_ops_floor_div(
-            network, target, None, kwargs_new, name
-        )
-    elif rounding_mode == "trunc":
-        return impl.elementwise.trunc_div(
-            network, target, SourceIR.ATEN, name, args[0], args[1]
-        )
-    else:
-        raise RuntimeError(
-            f"Target {target} does not support rounding mode {rounding_mode}"
-        )
-
-
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
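Note on the removal above: the old converter pre-cast int8/int32 TRT tensors to float32 before dividing because aten.div implements true division, which promotes integer inputs to floating point. A minimal PyTorch-level illustration of that promotion (example only, not part of this diff):

```python
import torch

# True division promotes integer inputs to floating point; the converter had
# to force float math on the TensorRT side to match this behavior.
torch.div(torch.tensor([7]), torch.tensor([2]))  # tensor([3.5000])
```

The replacement converter added further down delegates to impl.elementwise.div, which is expected to own this casting logic instead.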
@@ -846,24 +788,39 @@ def aten_ops_isinf(


 @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
 def aten_ops_add(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
     return impl.elementwise.add(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
+        other,
     )


 @dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
 def aten_ops_mul(
     network: TRTNetwork,
     target: Target,
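For context on the alpha handling introduced above: aten.add computes input + alpha * other, so when alpha != 1 the converter first scales other with impl.elementwise.mul. A PyTorch-level illustration (example only, not part of this diff):

```python
import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])

torch.add(x, y, alpha=3)  # tensor([31., 62.])   == x + 3 * y
torch.sub(x, y, alpha=3)  # tensor([-29., -58.]) == x - 3 * y
```

The same pattern is applied to aten_ops_sub in the next hunk.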
@@ -918,43 +875,86 @@ def aten_ops_min(


 @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
 def aten_ops_sub(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
     return impl.elementwise.sub(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
+        other,
     )


-# TODO: keep this or line 54...?
-# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-# def aten_ops_div(
-#     network: TRTNetwork,
-#     target: Target,
-#     args: Tuple[Argument, ...],
-#     kwargs: Dict[str, Argument],
-#     name: str,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     return impl.elementwise.div(
-#         network,
-#         target,
-#         SourceIR.ATEN,
-#         name,
-#         args[0],
-#         args[1],
-#     )
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+def aten_ops_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    rounding_mode = kwargs.get("rounding_mode")
+
+    if rounding_mode is None:
+        return impl.elementwise.div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "floor":
+        return impl.elementwise.floor_divide(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "trunc":
+        return impl.elementwise.trunc_div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    else:
+        raise RuntimeError(
+            f"Target {target} does not support rounding mode {rounding_mode}"
+        )


 @dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
 def aten_ops_pow(
     network: TRTNetwork,
     target: Target,
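The rounding_mode dispatch in the new aten_ops_div above mirrors PyTorch's division semantics; the three branches correspond to true division, floor division, and truncation. Illustration (example only, not part of this diff):

```python
import torch

x = torch.tensor([7, -7])
y = torch.tensor([2, 2])

torch.div(x, y)                         # tensor([ 3.5000, -3.5000])  true division
torch.div(x, y, rounding_mode="floor")  # tensor([ 3, -4])  rounds toward -inf
torch.div(x, y, rounding_mode="trunc")  # tensor([ 3, -3])  rounds toward zero
```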
@@ -973,6 +973,7 @@ def aten_ops_pow(


 @dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
 def aten_ops_floor_div(
     network: TRTNetwork,
     target: Target,
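On the two pow overloads registered at the end of the previous hunk: in the ATen operator schemas, pow.Tensor_Scalar takes a tensor base with a scalar exponent, while pow.Scalar takes a scalar base with a tensor exponent; both fit the same converter because it only reads positional args[0] and args[1]. Illustration (example only, not part of this diff):

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])
torch.pow(t, 2)  # aten.pow.Tensor_Scalar: tensor base, scalar exponent
torch.pow(2, t)  # aten.pow.Scalar: scalar base, tensor exponent
torch.pow(t, t)  # aten.pow.Tensor_Tensor
```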
@@ -1045,6 +1046,7 @@ def aten_ops_logical_xor(


 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
 def aten_ops_equal(
     network: TRTNetwork,
     target: Target,
@@ -1063,6 +1065,7 @@ def aten_ops_equal(


 @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
 def aten_ops_greater(
     network: TRTNetwork,
     target: Target,
@@ -1081,6 +1084,7 @@ def aten_ops_greater(


 @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
 def aten_ops_less(
     network: TRTNetwork,
     target: Target,
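Finally, a minimal smoke-test sketch for the newly registered Scalar/alpha/rounding paths. It assumes a CUDA environment with a Torch-TensorRT build that registers the "torch_tensorrt" backend for torch.compile; the function body and tolerances are illustrative, not taken from this commit:

```python
import torch
import torch_tensorrt  # noqa: F401  (importing registers the torch.compile backend)

def f(x):
    # Exercises add with alpha, div with rounding_mode, and pow.Tensor_Scalar.
    return torch.add(x, 2.0, alpha=3) + torch.div(x, 4.0, rounding_mode="floor") + x**2

compiled = torch.compile(f, backend="torch_tensorrt")
x = torch.randn(8, device="cuda")
torch.testing.assert_close(compiled(x), f(x), rtol=1e-3, atol=1e-3)
```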