1
1
import logging
2
2
from typing import Any , Dict , Optional , Sequence , Tuple , Union
3
3
4
+ import tensorrt as trt
4
5
import torch
5
6
from torch .fx .node import Argument , Node , Target
6
7
from torch_tensorrt .dynamo ._SourceIR import SourceIR
12
13
from torch_tensorrt .fx .converters import acc_ops_converters
13
14
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
14
15
15
- import tensorrt as trt
16
-
17
16
from .converter_registry import dynamo_tensorrt_converter
18
17
19
18
# Module-level logger, named after this module via ``__name__``.
_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -76,13 +75,13 @@ def aten_ops_div(
76
75
kwargs_new ["input" ].dtype == trt .int8 or kwargs_new ["input" ].dtype == trt .int32
77
76
):
78
77
kwargs_new ["input" ] = cast_trt_tensor (
79
- network , kwargs_new ["input" ], trt .float32 , name
78
+ network , kwargs_new ["input" ], trt .float32 , name , target
80
79
)
81
80
elif isinstance (args [1 ], TRTTensor ) and (
82
81
kwargs_new ["other" ].dtype == trt .int8 or kwargs_new ["other" ].dtype == trt .int32
83
82
):
84
83
kwargs_new ["other" ] = cast_trt_tensor (
85
- network , kwargs_new ["other" ], trt .float32 , name
84
+ network , kwargs_new ["other" ], trt .float32 , name , target
86
85
)
87
86
rounding_mode = kwargs .get ("rounding_mode" )
88
87
if rounding_mode is None :
@@ -101,7 +100,7 @@ def aten_ops_div(
101
100
)
102
101
103
102
104
- def embedding_param_validator (embedding_node : Node ):
103
+ def embedding_param_validator (embedding_node : Node ) -> bool :
105
104
scale_grad_by_freq = args_bounds_check (embedding_node .args , 3 )
106
105
sparse = args_bounds_check (embedding_node .args , 4 )
107
106
@@ -365,3 +364,59 @@ def aten_ops_permute(
365
364
args [0 ],
366
365
args [1 ],
367
366
)
367
+
368
+
369
def to_copy_dtype_validator(to_copy_node: Node) -> bool:
    """Decide whether a ``_to_copy`` node carries a dtype this converter accepts.

    Args:
        to_copy_node: The FX node whose ``kwargs`` are inspected.

    Returns:
        ``True`` only when ``kwargs`` contains a ``"dtype"`` entry whose value
        is one of ``torch.float``, ``torch.int32``, ``torch.bool``,
        ``torch.int8`` or ``torch.float16``. In every other case a debug
        message is logged and ``False`` is returned.
    """
    supported_dtypes = {
        torch.float,
        torch.int32,
        torch.bool,
        torch.int8,
        torch.float16,
    }

    # A node that does not specify a dtype at all is rejected up front.
    if "dtype" not in to_copy_node.kwargs:
        _LOGGER.debug(
            f"_to_copy converter rejected node { to_copy_node } with kwargs { to_copy_node .kwargs } "
        )
        return False

    # A dtype is present, but it must be one of the supported cast targets.
    if to_copy_node.kwargs["dtype"] not in supported_dtypes:
        _LOGGER.debug(
            f"_to_copy converter rejected node { to_copy_node } with dtype { to_copy_node .kwargs ['dtype' ]} "
        )
        return False

    return True
386
+
387
+
388
@dynamo_tensorrt_converter(
    torch.ops.aten._to_copy.default,
    capability_validator=to_copy_dtype_validator,
)
def aten_ops_to_copy_dtype(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter registered for ``torch.ops.aten._to_copy.default``.

    Registration is gated by ``to_copy_dtype_validator``. The work itself is
    delegated to ``impl.cast.to_copy``, tagged with ``SourceIR.ATEN``.

    Args:
        network: Network object forwarded unchanged to the implementation.
        target: The op target forwarded unchanged to the implementation.
        args: Positional node arguments; only ``args[0]`` is used.
        kwargs: Keyword node arguments; ``kwargs["dtype"]`` is required.
        name: Name forwarded unchanged to the implementation.

    Returns:
        Whatever ``impl.cast.to_copy`` returns.
    """
    # First positional argument is the value to cast; "dtype" is the target type.
    cast_input, requested_dtype = args[0], kwargs["dtype"]
    return impl.cast.to_copy(
        network, target, SourceIR.ATEN, name, cast_input, requested_dtype
    )
406
+
407
+
408
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter registered for ``torch.ops.aten.clone.default``.

    Delegates to ``impl.cast.clone``, tagged with ``SourceIR.ATEN``.

    Args:
        network: Network object forwarded unchanged to the implementation.
        target: The op target forwarded unchanged to the implementation.
        args: Positional node arguments; only ``args[0]`` is used.
        kwargs: Keyword node arguments; not read by this converter.
        name: Name forwarded unchanged to the implementation.

    Returns:
        Whatever ``impl.cast.clone`` returns.
    """
    clone_input = args[0]
    return impl.cast.clone(network, target, SourceIR.ATEN, name, clone_input)
0 commit comments