Skip to content

Commit b424735

Browse files
borisfomgs-olive
authored andcommitted
aten::unsqueeze impl refactor
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
1 parent b4da15e commit b424735

File tree

4 files changed

+144
-29
lines changed

4 files changed

+144
-29
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+8-29
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from torch_tensorrt.fx.converters.impl.unary.base import convert_unary
4040
from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape
4141
from torch_tensorrt.fx.converters.impl.squeeze import squeeze
42+
from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
4243

4344
_LOGGER: logging.Logger = logging.getLogger(__name__)
4445

@@ -2215,36 +2216,14 @@ def acc_ops_unsqueeze(
22152216
kwargs: Dict[str, Argument],
22162217
name: str,
22172218
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2218-
input_t = kwargs["input"]
2219-
input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
2220-
if not isinstance(input_val, TRTTensor):
2221-
raise RuntimeError(
2222-
f"unsqueeze received input {input_val} that is not part "
2223-
"of the TensorRT region!"
2224-
)
2225-
2226-
dim = cast(int, kwargs["dim"])
2227-
input_shape = input_val.shape
2228-
input_shape_size = (
2229-
len(input_val.shape) + 1
2230-
if network.has_implicit_batch_dimension
2231-
else len(input_val.shape)
2232-
)
2233-
dim = get_positive_dim(dim, input_shape_size + 1)
2234-
2235-
if network.has_implicit_batch_dimension:
2236-
assert dim != 0
2237-
dim -= 1
2238-
2239-
assert (
2240-
len(get_dynamic_dims(input_val.shape)) <= 1
2241-
), "Currently we don't support unsqueeze with more than one dynamic dims."
2242-
layer = network.add_shuffle(input_val)
2243-
layer.reshape_dims = (
2244-
tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
2219+
return unsqueeze(
2220+
network,
2221+
target,
2222+
SourceIR.ACC,
2223+
name,
2224+
input_t=kwargs["input"],
2225+
dim=kwargs["dim"],
22452226
)
2246-
set_layer_name(layer, target, name)
2247-
return layer.get_output(0)
22482227

22492228

22502229
@tensorrt_converter(acc_ops.topk)

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from torch_tensorrt.fx.converters.impl.slice import slice_op
3434
from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
3535
from torch_tensorrt.fx.converters.impl.condition import where
36+
from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
3637

3738
_LOGGER: logging.Logger = logging.getLogger(__name__)
3839

@@ -485,6 +486,17 @@ def aten_ops_squeeze(
485486
return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
486487

487488

489+
@tensorrt_converter(torch.ops.aten.unsqueeze.default)
490+
def aten_ops_unsqueeze(
491+
network: TRTNetwork,
492+
target: Target,
493+
args: Tuple[Argument, ...],
494+
kwargs: Dict[str, Argument],
495+
name: str,
496+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
497+
return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
498+
499+
488500
@tensorrt_converter(torch.ops.aten.view.default)
489501
def aten_ops_reshape(
490502
network: TRTNetwork,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import operator
2+
import warnings
3+
from typing import Optional, cast, Any
4+
5+
import numpy as np
6+
7+
import tensorrt as trt
8+
import torch
9+
from torch.fx.node import Target
10+
11+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape
12+
from torch_tensorrt.fx.converters.converter_utils import (
13+
SourceIR,
14+
get_positive_dim,
15+
get_trt_tensor,
16+
)
17+
18+
from torch_tensorrt.fx.converters.converter_utils import (
19+
SourceIR,
20+
get_positive_dim,
21+
set_layer_name,
22+
)
23+
24+
from torch_tensorrt.fx.utils import get_dynamic_dims
25+
26+
27+
def unsqueeze(
28+
network: TRTNetwork,
29+
target: Target,
30+
source_ir: Optional[SourceIR],
31+
name: str,
32+
input_t,
33+
dim,
34+
) -> TRTTensor:
35+
input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
36+
if not isinstance(input_val, TRTTensor):
37+
raise RuntimeError(
38+
f"unsqueeze received input {input_val} that is not part "
39+
"of the TensorRT region!"
40+
)
41+
42+
dim = cast(int, dim)
43+
input_shape = input_val.shape
44+
input_shape_size = (
45+
len(input_val.shape) + 1
46+
if network.has_implicit_batch_dimension
47+
else len(input_val.shape)
48+
)
49+
dim = get_positive_dim(dim, input_shape_size + 1)
50+
51+
if network.has_implicit_batch_dimension:
52+
assert dim != 0
53+
dim -= 1
54+
55+
assert (
56+
len(get_dynamic_dims(input_val.shape)) <= 1
57+
), "Currently we don't support unsqueeze with more than one dynamic dims."
58+
layer = network.add_shuffle(input_val)
59+
layer.reshape_dims = (
60+
tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
61+
)
62+
set_layer_name(layer, target, name)
63+
return layer.get_output(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import torch.fx
3+
import torch.nn as nn
4+
from parameterized import parameterized
5+
from torch.testing._internal.common_utils import run_tests
6+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
7+
8+
9+
class TestUnsqueeze(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("negative_dim", -2),
13+
("positive_dim", 2),
14+
]
15+
)
16+
def test_unsqueeze(self, _, dim):
17+
class Unsqueeze(nn.Module):
18+
def __init__(self, dim):
19+
super().__init__()
20+
self.dim = dim
21+
22+
def forward(self, x):
23+
return torch.unsqueeze(x, self.dim)
24+
25+
inputs = [torch.randn(1, 2, 3)]
26+
self.run_test(
27+
Unsqueeze(dim), inputs, expected_ops={torch.ops.aten.unsqueeze.default}
28+
)
29+
30+
# Testing with more than one dynamic dims results in following error:
31+
# AssertionError: Currently we don't support unsqueeze with more than one dynamic dims.
32+
33+
@parameterized.expand(
34+
[
35+
("negative_dim_dynamic", -4),
36+
("positive_dim_dynamic", 1),
37+
]
38+
)
39+
def test_unsqueeze_with_dynamic_shape(self, _, dim):
40+
class Unsqueeze(nn.Module):
41+
def __init__(self, dim):
42+
super().__init__()
43+
self.dim = dim
44+
45+
def forward(self, x):
46+
return torch.unsqueeze(x, self.dim)
47+
48+
input_specs = [
49+
InputTensorSpec(
50+
shape=(-1, 2, 3),
51+
dtype=torch.float32,
52+
shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))],
53+
),
54+
]
55+
self.run_test_with_dynamic_shape(
56+
Unsqueeze(dim), input_specs, expected_ops={torch.ops.aten.unsqueeze.default}
57+
)
58+
59+
60+
if __name__ == "__main__":
61+
run_tests()

0 commit comments

Comments
 (0)