1
1
from typing import Optional
2
2
3
+ import tensorrt as trt
3
4
from torch .fx .node import Target
4
5
from torch_tensorrt .dynamo ._SourceIR import SourceIR
5
6
from torch_tensorrt .fx .converters .converter_utils import (
10
11
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
11
12
from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
12
13
13
- import tensorrt as trt
14
-
15
14
16
15
def matrix_multiply (
17
16
network : TRTNetwork ,
@@ -20,6 +19,8 @@ def matrix_multiply(
20
19
name : str ,
21
20
input : TRTTensor ,
22
21
other : TRTTensor ,
22
+ input_matrix_op : trt .MatrixOperation = trt .MatrixOperation .NONE ,
23
+ other_matrix_op : trt .MatrixOperation = trt .MatrixOperation .NONE ,
23
24
) -> TRTTensor :
24
25
if not isinstance (input , trt .tensorrt .ITensor ):
25
26
input = get_trt_tensor (network , input , f"{ name } _input" )
@@ -31,7 +32,6 @@ def matrix_multiply(
31
32
dtype = unified_dtype_converter (input .dtype , Frameworks .TORCH ),
32
33
)
33
34
34
- input_matrix_op = other_matrix_op = trt .MatrixOperation .NONE
35
35
preset_diff = 0
36
36
37
37
if len (input .shape ) == 1 :
@@ -46,5 +46,5 @@ def matrix_multiply(
46
46
network , input , other , f"{ name } _input" , f"{ name } _other" , preset_diff
47
47
)
48
48
layer = network .add_matrix_multiply (input , input_matrix_op , other , other_matrix_op )
49
- set_layer_name (layer , target , name )
49
+ set_layer_name (layer , target , name , source_ir )
50
50
return layer .get_output (0 )
0 commit comments