Skip to content

Commit 31a2cba

Browse files
committed
aten::cat converter moving to impl
1 parent 359b6b7 commit 31a2cba

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def aten_ops_cat(
6060
SourceIR.ATEN,
6161
name,
6262
tensors=args[0],
63-
dim = args_bounds_check(args, 2, 1),
63+
dim=args_bounds_check(args, 2, 1),
6464
)
6565

6666

Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
from typing import Optional, Union, Sequence, Dict
1+
from typing import Dict, Optional, Sequence, Union
22

33
import torch
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
77
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
88

9+
910
def cat(
1011
network: TRTNetwork,
1112
target: Target,
1213
source_ir: Optional[SourceIR],
1314
name: str,
1415
input: TRTNetwork,
1516
dim: int,
16-
1717
) -> Union[TRTTensor, Sequence[TRTTensor]]:
18-
1918
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
2019
raise RuntimeError(
2120
f"cat received inputs {input} that is not part " "of the TensorRT region!"
@@ -26,4 +25,4 @@ def cat(
2625

2726
concat_layer.axis = dim
2827
set_layer_name(concat_layer, target, name + "_gather", source_ir)
29-
return concat_layer.get_output(0)
28+
return concat_layer.get_output(0)

0 commit comments

Comments
 (0)