Skip to content

Commit 5c4e9ac

Browse files
committed
Key op fixes for failing tests
1 parent 78f000f commit 5c4e9ac

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,15 @@ def layer_norm(
112112
"of the TensorRT region!"
113113
)
114114

115-
gamma = weight.detach().cpu().float().numpy()
115+
gamma = (
116+
weight.detach().cpu().float().numpy()
117+
if isinstance(weight, torch.Tensor)
118+
else weight
119+
)
116120
gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
117-
beta = bias.detach().cpu().float().numpy()
121+
beta = (
122+
bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
123+
)
118124
beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
119125
eps_field = trt.PluginField(
120126
"eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32

py/torch_tensorrt/dynamo/conversion/impl/permutation.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional, Sequence, cast
1+
from typing import Optional, Sequence, Union, cast
2+
import numpy as np
23

34

45
from torch.fx.node import Target
@@ -8,6 +9,7 @@
89
from torch_tensorrt.fx.converters.converter_utils import (
910
set_layer_name,
1011
get_positive_dim,
12+
get_trt_tensor,
1113
)
1214

1315

@@ -16,9 +18,15 @@ def permute(
1618
target: Target,
1719
source_ir: Optional[SourceIR],
1820
name: str,
19-
input: TRTTensor,
21+
input: Union[TRTTensor, np.ndarray],
2022
permutation: Sequence[int],
2123
) -> TRTTensor:
24+
if isinstance(input, np.ndarray):
25+
tensor_to_freeze = np.transpose(input, permutation)
26+
# TODO: Fix naming for constant tensors
27+
frozen_trt_tensor = get_trt_tensor(network, tensor_to_freeze, name)
28+
return frozen_trt_tensor
29+
2230
if not isinstance(input, TRTTensor):
2331
raise RuntimeError(
2432
f"permute received input {input} that is not a TensorRT ITensor"

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2719,8 +2719,14 @@ def acc_ops_linear(
27192719
"dim for linear and it can't be the last dim."
27202720
)
27212721

2722-
if isinstance(kwargs["weight"], torch.Tensor):
2723-
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
2722+
if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
2723+
weight = get_trt_tensor(
2724+
network,
2725+
kwargs["weight"].t()
2726+
if isinstance(kwargs["weight"], torch.Tensor)
2727+
else kwargs["weight"].T,
2728+
f"{name}_weight",
2729+
)
27242730
if target not in (acc_ops.linear, torch.ops.aten.linear):
27252731
weight_op = trt.MatrixOperation.TRANSPOSE
27262732
else:

py/torch_tensorrt/fx/converters/converter_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def create_constant(
271271
"""
272272
constant = network.add_constant(
273273
(1,) if isinstance(value, (int, float)) else value.shape,
274-
to_numpy(value, dtype),
274+
to_numpy(value, dtype).copy(),
275275
)
276276
constant.name = name
277277
return constant.get_output(0)

py/torch_tensorrt/fx/converters/impl/convolution.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def convNd(
5555
)
5656

5757
# Process bias terms
58-
if isinstance(bias, torch.Tensor):
58+
if isinstance(bias, (torch.Tensor, np.ndarray)):
5959
# Transform the bias constant into a Numpy array
6060
bias = to_numpy(bias)
6161

@@ -80,7 +80,7 @@ def convNd(
8080
network, target, tuple(), kwargs, name + "_unsqueeze_weight"
8181
)
8282

83-
elif isinstance(weight, torch.Tensor):
83+
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8484
# Transform the weight constant into a Numpy array
8585
weight = to_numpy(weight)
8686

0 commit comments

Comments
 (0)