Skip to content
Prev Previous commit
Next Next commit
clean flatten converter and add tests
  • Loading branch information
zewenli98 committed Oct 5, 2023
commit c0cfeb82eb705f33dd52908d55f6d3ded1e0d93d
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def aten_ops_slice(
{
0: (TRTTensor,),
}
)
) # type: ignore[misc]
def aten_ops_permute(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -1394,7 +1394,7 @@ def conv_param_validator(conv_node: Node) -> bool:
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
)
) # type: ignore[misc]
def aten_ops_convolution(
ctx: ConversionContext,
target: Target,
Expand Down
40 changes: 39 additions & 1 deletion tests/py/dynamo/conversion/test_converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
from torch_tensorrt.dynamo.conversion.converter_utils import (
enforce_tensor_types,
flatten_dims,
)
from torch_tensorrt.fx.types import TRTTensor

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
Expand Down Expand Up @@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
enforce_tensor_types({0: (int, bool)})


class TestFlattenDimsEnforcement(TestCase):
@parameterized.expand(
[
((1, 2), 0, 0, (1, 2)),
((1, 2), 0, 1, (2,)),
((2, 3, 4), 1, 2, (2, 12)),
((2, 3, 4), 0, 1, (6, 4)),
((2, 3, 4), -3, 2, (24,)),
((2, 3, 4, 5), 0, -2, (24, 5)),
((2, 3, 4, 5), -4, -1, (120,)),
]
)
def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
inputs = np.random.randn(*input_shape)
new_shape = flatten_dims(inputs, start_dim, end_dim)
self.assertEqual(new_shape, true_shape)

@parameterized.expand(
[
((1, 2), 0, 0, (1, 2)),
((1, 2), 0, 1, (2,)),
((2, 3, 4), 1, 2, (2, 12)),
((2, 3, 4), 0, 1, (6, 4)),
((2, 3, 4), -3, 2, (24,)),
((2, 3, 4, 5), 0, -2, (24, 5)),
((2, 3, 4, 5), -4, -1, (120,)),
]
)
def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
inputs = torch.randn(input_shape)
new_shape = flatten_dims(inputs, start_dim, end_dim)
self.assertEqual(new_shape, true_shape)


if __name__ == "__main__":
run_tests()