
fix: Add support for negative dimensions in reduce #2347

Merged: 1 commit, Sep 29, 2023
40 changes: 39 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,7 +1,7 @@
import functools
import logging
import re
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
@@ -314,3 +314,41 @@ def get_trt_tensor(
return input_val
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
def get_positive_dim(dim: int, dim_size: int) -> int:
    ...


@overload
def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
    ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
    """
    Given an integer or a sequence of integers representing dimension(s) of
    an array, convert any negative dimension to its positive equivalent;
    positive dimensions are returned unchanged.

    Args:
        dim (Union[int, Sequence[int]]): An integer or sequence of integers representing dimension(s) in an array.
        dim_size (int): The number of dimensions (rank) of the array.

    Returns:
        A positive integer or tuple of positive integers representing the same dimension(s) as the given dim.
    """

    def positive_dim(d: int) -> int:
        if d < 0:
            return d % dim_size
        return d

    return (
        positive_dim(dim)
        if isinstance(dim, int)
        else tuple(positive_dim(d) for d in dim)
    )
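
For reference, the normalization leans on Python's modulo semantics for negative operands: with a positive modulus, `d % dim_size` is always non-negative, so `-1` maps to `dim_size - 1`, `-2` to `dim_size - 2`, and so on. A minimal standalone sketch of the same logic (plain Python, no TensorRT dependency; `positive_dim_sketch` is an illustrative name, not part of this PR):

```python
from typing import Sequence, Tuple, Union

def positive_dim_sketch(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
    # -1 % 4 == 3, -2 % 4 == 2, ...: Python's % is non-negative
    # whenever the modulus is positive.
    def norm(d: int) -> int:
        return d % dim_size if d < 0 else d

    return norm(dim) if isinstance(dim, int) else tuple(norm(d) for d in dim)

assert positive_dim_sketch(-1, 4) == 3             # last axis of a rank-4 tensor
assert positive_dim_sketch(2, 4) == 2              # positive dims pass through
assert positive_dim_sketch([-3, -1], 4) == (1, 3)  # sequences come back as tuples
```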
@@ -6,12 +6,12 @@
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
get_trt_plugin,
has_dynamic_shape,
set_layer_name,
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/permutation.py
@@ -2,10 +2,8 @@

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


@@ -22,7 +20,7 @@ def permute(
f"permute received input {input} that is not a TensorRT ITensor"
)

permutation = [get_positive_dim(i, len(input.shape)) for i in permutation]
permutation = get_positive_dim(permutation, len(input.shape))

layer = network.add_shuffle(input)
layer.second_transpose = tuple(permutation)
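
As a quick illustration (hypothetical values, not part of the diff): for a rank-3 input, the single call normalizes the whole permutation at once and returns a tuple, so the `tuple(permutation)` passed to `second_transpose` behaves as before:

```python
# Sketch of the new call for a rank-3 input (len(input.shape) == 3):
# get_positive_dim([0, -1, -2], 3) -> (0, 2, 1)
# matching the old element-wise list comprehension:
assert tuple(i % 3 if i < 0 else i for i in [0, -1, -2]) == (0, 2, 1)
```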
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -1,11 +1,12 @@
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_axes_for_reduce_op,
get_positive_dim,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -17,7 +18,7 @@ def amax(
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: Union[int, Tuple[int]],
dim: Union[int, Sequence[int]],
keepdim: bool = False,
) -> TRTTensor:
if (isinstance(input_val, TRTTensor)) and (
@@ -28,7 +29,7 @@
layer = network.add_reduce(
input_val,
trt.ReduceOperation.MAX,
axes=get_axes_for_reduce_op(dim),
axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
keep_dims=keepdim,
)
set_layer_name(layer, target, name, source_ir)
@@ -54,7 +55,7 @@ def sum(
layer = network.add_reduce(
input_val,
trt.ReduceOperation.SUM,
axes=get_axes_for_reduce_op(dim),
axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
keep_dims=keepdim,
)
set_layer_name(layer, target, name, source_ir)
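
Why the reduce converters need the normalization: TensorRT reduce layers take their axes as a bitmask, and `get_axes_for_reduce_op` builds that mask with left shifts (assuming it follows the fx converter utils, roughly `axes |= 1 << d`), so a negative dim would raise `ValueError: negative shift count` before ever reaching TensorRT. A sketch of the failure mode and the fix:

```python
from typing import Sequence, Union

def axes_bitmask_sketch(dim: Union[int, Sequence[int]]) -> int:
    # Bit i set in the mask => reduce over axis i.
    dims = (dim,) if isinstance(dim, int) else dim
    axes = 0
    for d in dims:
        axes |= 1 << d  # raises ValueError for d < 0
    return axes

# dim=-1 on a rank-4 tensor: normalize first, then build the mask.
assert axes_bitmask_sketch(-1 % 4) == 0b1000   # axis 3
# axes_bitmask_sketch(-1) would raise ValueError instead.
```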
7 changes: 2 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -3,12 +3,9 @@
import numpy as np
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
has_dynamic_shape,
to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape, to_numpy
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor


2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -3,9 +3,9 @@

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
has_dynamic_shape,
prepend_ones,
set_layer_name,
27 changes: 12 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -1,11 +1,9 @@
from typing import Any, Optional, cast
from typing import Optional, Sequence, Union

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -16,19 +14,18 @@ def squeeze(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Optional[Any] = None,
dim: Optional[Union[int, Sequence[int]]] = None,
) -> TRTTensor:
dims = []
if dim is not None:
if isinstance(dim, int):
dims.append(cast(Optional[int], dim))
else:
for dim in dim:
dims.append(cast(Optional[int], dim))

# Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
# dim, which is a very rare case. For now we just claim not supporting dim=None.
assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
assert dim is not None, "We don't support dim=None right now for squeeze."
dims = []

if isinstance(dim, int):
dims.append(dim)
else:
for dim in dim:
dims.append(dim)

new_dims = []
for dim in dims:
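
The reshuffled control flow above amounts to a plain normalization of `dim` into a list once `None` has been ruled out; an equivalent one-liner sketch (not what the diff uses):

```python
# Assumes the `assert dim is not None` above has already run.
dims = [dim] if isinstance(dim, int) else list(dim)
```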
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -2,11 +2,11 @@

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import (
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
set_layer_name,
get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

3 changes: 3 additions & 0 deletions tests/py/dynamo/conversion/test_amax_aten.py
@@ -13,6 +13,7 @@ class TestAmaxConverter(DispatchTestCase):
((2, 3, 4, 5), 3, True),
((2, 3, 4, 5), 2, False),
((6, 7, 5, 4, 5), 4, False),
((1, 5, 2, 1), -1, True),
]
)
def test_amax_dim_int_default(self, input_shape, dim, keep_dims):
@@ -53,6 +54,7 @@ def forward(self, x):
((2, 3, 4, 5), 3, True, torch.int, -10, 10),
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
]
)
def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
@@ -74,6 +76,7 @@ def forward(self, x):
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
]
)
def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
3 changes: 3 additions & 0 deletions tests/py/dynamo/conversion/test_sum_aten.py
@@ -33,6 +33,8 @@ def forward(self, x):
((2, 3, 4, 5), 3, True),
((2, 3, 4, 5), None, False),
((6, 7, 5, 4, 5), 4, False),
((1, 5, 2, 1), -3, False),
((1, 5, 2, 3), -2, True),
]
)
def test_sum_dim_int(self, input_shape, dim, keep_dims):
@@ -53,6 +55,7 @@ def forward(self, x):
((2, 1, 4, 5), None, True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
((6, 7, 5, 4, 5), [1, 3, 4], False),
((6, 7, 5, 4, 5), [-5, -4, -2], False),
]
)
def test_sum_dim_tuple(self, input_shape, dim, keep_dims):