24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1555,3 +1555,27 @@ def tensorrt_scaled_dot_product_attention(
     return impl.attention.scaled_dot_product_attention(
         ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.reshape.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.view.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_reshape(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.reshape(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        shape=args[1],
+    )
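For orientation, a minimal sketch (not part of the diff) of a module that would exercise this converter through the Dynamo frontend; the compile call is illustrative, assumes a CUDA device, and is left commented out for that reason:

import torch
import torch_tensorrt

class ReshapeModule(torch.nn.Module):
    def forward(self, x):
        # Traces to aten.view.default / aten.reshape.default, which the
        # converter above lowers to a TensorRT shuffle layer.
        return x.reshape(1, 10, -1)

# trt_mod = torch_tensorrt.compile(
#     ReshapeModule().eval().cuda(),
#     ir="dynamo",
#     inputs=[torch.randn(1, 2, 10).cuda()],
# )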
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -511,3 +511,35 @@ def to_numpy(
     raise AssertionError(
         f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
     )
+
+
+def flatten_dims(
+    input: Union[TRTTensor, torch.Tensor, np.ndarray],
+    start_dim: int,
+    end_dim: int,
+) -> Tuple[int, ...]:
+    """
+    Given an input and the start/end indices of the dimensions to merge,
+    return the new shape with that dimension range flattened.
+
+    Args:
+        input (Union[TRTTensor, torch.Tensor, np.ndarray]):
+            the input whose shape is to be flattened
+        start_dim (int): the first dim to flatten
+        end_dim (int): the last dim to flatten (this dim is included)
+
+    Returns:
+        Tuple[int, ...]: the new shape
+    """
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])
+
+    return new_shape
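The helper's behavior, mirrored from the parameterized cases in the new tests further down (this usage sketch assumes an import from torch_tensorrt.dynamo.conversion.converter_utils, as the test file does):

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion.converter_utils import flatten_dims

x = np.random.randn(2, 3, 4)
assert flatten_dims(x, 1, 2) == (2, 12)   # merge dims 1..2
assert flatten_dims(x, -3, 2) == (24,)    # negative start_dim is resolved first

t = torch.randn(2, 3, 4, 5)
assert flatten_dims(t, 0, -2) == (24, 5)  # flatten everything but the last dim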
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -17,6 +17,7 @@
     reduce,
     select,
     shape,
+    shuffle,
     slice,
     split,
     squeeze,
21 changes: 21 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -0,0 +1,21 @@
+from typing import Optional, Sequence, Union
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def reshape(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    shape: Sequence[int],
+) -> TRTTensor:
+    layer = ctx.net.add_shuffle(input)
+    layer.reshape_dims = tuple(shape)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
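TensorRT's IShuffleLayer applies reshape_dims with torch-like semantics: at most one entry may be -1, and that dimension is inferred from the remaining element count. A sketch of how the converter above calls this helper; ctx and input_trt are hypothetical stand-ins for objects the conversion machinery supplies:

import torch

# `ctx` is the active ConversionContext and `input_trt` is a TRTTensor of
# shape (1, 2, 10) produced by earlier layers (both assumed here).
out = reshape(
    ctx,
    torch.ops.aten.view.default,
    SourceIR.ATEN,
    "example_reshape",
    input=input_trt,
    shape=(1, 10, -1),  # TensorRT infers the -1 dimension as 2
)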
40 changes: 39 additions & 1 deletion tests/py/dynamo/conversion/test_converter_utils.py
@@ -1,7 +1,11 @@
 import numpy as np
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    enforce_tensor_types,
+    flatten_dims,
+)
 from torch_tensorrt.fx.types import TRTTensor
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
         enforce_tensor_types({0: (int, bool)})
 
 
+class TestFlattenDimsEnforcement(TestCase):
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = np.random.randn(*input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = torch.randn(input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+
 if __name__ == "__main__":
     run_tests()
21 changes: 11 additions & 10 deletions tests/py/dynamo/conversion/test_reshape_aten.py
@@ -12,6 +12,8 @@
 class TestReshapeConverter(DispatchTestCase):
     @parameterized.expand(
         [
+            ((-1,),),
+            ((20,),),
             ((1, 20),),
             ((1, 10, -1),),
         ]
@@ -21,22 +23,22 @@ class TestReshapeConverter(DispatchTestCase):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         inputs = [torch.randn(1, 2, 10)]
         self.run_test(
-            TestModule(target_shape),
+            Reshape(),
             inputs,
         )
 
     @parameterized.expand(
         [
+            ((-1,),),
             ((-1, 10),),
             ((-1, 5),),
             ((2, 2, -1),),
@@ -47,13 +49,12 @@ def forward(self, x):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape_with_dynamic_shape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         input_specs = [
             Input(
@@ -63,7 +64,7 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            TestModule(target_shape),
+            Reshape(),
             input_specs,
         )
 
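The Input(...) fields are collapsed in the view above; for reference, a dynamic-shape spec for run_test_with_dynamic_shape typically looks like the following (values here are illustrative, not the ones from the test):

import torch
from torch_tensorrt import Input

# Illustrative only; the actual min/opt/max values are collapsed above.
input_specs = [
    Input(
        min_shape=(1, 1, 10),
        opt_shape=(2, 2, 10),
        max_shape=(4, 4, 10),
        dtype=torch.float32,
    ),
]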