Skip to content

Commit b0d4c94

Browse files
committed
fix: improve matmul dynamo converter
fix bug add args input_matrix_op and other_matrix_op, support aten.mv.default minor fix
1 parent e49ef6d commit b0d4c94

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def aten_ops_gelu(
171171

172172
@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc]
173173
@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc]
174+
@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc]
174175
def aten_ops_matmul(
175176
network: TRTNetwork,
176177
target: Target,
@@ -179,7 +180,12 @@ def aten_ops_matmul(
179180
name: str,
180181
) -> Union[TRTTensor, Sequence[TRTTensor]]:
181182
return impl.matmul.matrix_multiply(
182-
network, target, SourceIR.ATEN, name, args[0], args[1]
183+
network,
184+
target,
185+
SourceIR.ATEN,
186+
name,
187+
args[0],
188+
args[1],
183189
)
184190

185191

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import tensorrt as trt
34
from torch.fx.node import Target
45
from torch_tensorrt.dynamo._SourceIR import SourceIR
56
from torch_tensorrt.fx.converters.converter_utils import (
@@ -10,8 +11,6 @@
1011
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1112
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1213

13-
import tensorrt as trt
14-
1514

1615
def matrix_multiply(
1716
network: TRTNetwork,
@@ -20,6 +19,8 @@ def matrix_multiply(
2019
name: str,
2120
input: TRTTensor,
2221
other: TRTTensor,
22+
input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
23+
other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
2324
) -> TRTTensor:
2425
if not isinstance(input, trt.tensorrt.ITensor):
2526
input = get_trt_tensor(network, input, f"{name}_input")
@@ -31,7 +32,6 @@ def matrix_multiply(
3132
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
3233
)
3334

34-
input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
3535
preset_diff = 0
3636

3737
if len(input.shape) == 1:
@@ -46,5 +46,5 @@ def matrix_multiply(
4646
network, input, other, f"{name}_input", f"{name}_other", preset_diff
4747
)
4848
layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
49-
set_layer_name(layer, target, name)
49+
set_layer_name(layer, target, name, source_ir)
5050
return layer.get_output(0)

0 commit comments

Comments
 (0)