
Commit 57544fc ("Addressing review comments")
Parent: 6cf4d6c
1 file changed (+14 -8):

py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -1,11 +1,16 @@
 from typing import Dict, Optional, Sequence, Union
 
+import numpy as np
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
@@ -14,15 +19,16 @@ def cat(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: Union[TRTTensor, Sequence[TRTTensor]],
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     dim: int,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    trt_inputs = []
     for each_input in input:
-        if(not isinstance(each_input, TRTTensor)):
-            each_input = get_trt_tensor(each_input)
-    concat_layer = ctx.net.add_concatenation(input)
-    if dim < 0:
-        dim = len(input[0].shape) + dim
+        if not isinstance(each_input, TRTTensor):
+            each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}")
+        trt_inputs.append(each_input)
+    concat_layer = ctx.net.add_concatenation(trt_inputs)
+    dim = get_positive_dim(dim, len(input[0].shape))
 
     concat_layer.axis = dim
     set_layer_name(concat_layer, target, name + "_gather", source_ir)
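For readability, here is a minimal sketch of the cat converter as it stands after this commit. The diff hunks begin at the `target` parameter and end at `set_layer_name`, so the leading `ctx` parameter and the trailing return statement are assumptions inferred from the `ctx.net` usage and typical converter structure; only the imports the sketch actually uses are shown.

from typing import Optional, Sequence, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    SourceIR,
    get_positive_dim,
    get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def cat(
    ctx: ConversionContext,  # assumed: the lines above `target` are outside the diff hunk
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
    dim: int,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Convert any non-TRT inputs (torch.Tensor / np.ndarray) into ITensors
    # before handing them to the concatenation layer.
    trt_inputs = []
    for each_input in input:
        if not isinstance(each_input, TRTTensor):
            # Note: "_tensor_{i}" is a plain string in the committed code (not
            # an f-string), so every converted input shares the same suffix.
            each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}")
        trt_inputs.append(each_input)
    concat_layer = ctx.net.add_concatenation(trt_inputs)
    # Map a negative dim (e.g. -1) to its positive equivalent.
    dim = get_positive_dim(dim, len(input[0].shape))

    concat_layer.axis = dim
    set_layer_name(concat_layer, target, name + "_gather", source_ir)
    return concat_layer.get_output(0)  # assumed: the return is not shown in the diff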
