1
1
from typing import Dict , Optional , Sequence , Union
2
2
3
+ import numpy as np
3
4
import torch
4
5
from torch .fx .node import Target
5
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
6
- from torch_tensorrt .fx .converters .converter_utils import set_layer_name
7
7
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8
- from torch_tensorrt .dynamo .conversion .converter_utils import get_trt_tensor
8
+ from torch_tensorrt .dynamo .conversion .converter_utils import (
9
+ SourceIR ,
10
+ get_positive_dim ,
11
+ get_trt_tensor ,
12
+ )
13
+ from torch_tensorrt .fx .converters .converter_utils import set_layer_name
9
14
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
10
15
11
16
@@ -14,15 +19,16 @@ def cat(
14
19
target : Target ,
15
20
source_ir : Optional [SourceIR ],
16
21
name : str ,
17
- input : Union [TRTTensor , Sequence [ TRTTensor ]],
22
+ input : Sequence [ Union [TRTTensor , torch . Tensor , np . ndarray ]],
18
23
dim : int ,
19
24
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
25
+ trt_inputs = []
20
26
for each_input in input :
21
- if ( not isinstance (each_input , TRTTensor ) ):
22
- each_input = get_trt_tensor (each_input )
23
- concat_layer = ctx . net . add_concatenation ( input )
24
- if dim < 0 :
25
- dim = len (input [0 ].shape ) + dim
27
+ if not isinstance (each_input , TRTTensor ):
28
+ each_input = get_trt_tensor (ctx , each_input , name + "_tensor_{i}" )
29
+ trt_inputs . append ( each_input )
30
+ concat_layer = ctx . net . add_concatenation ( trt_inputs )
31
+ dim = get_positive_dim ( dim , len (input [0 ].shape ))
26
32
27
33
concat_layer .axis = dim
28
34
set_layer_name (concat_layer , target , name + "_gather" , source_ir )
0 commit comments