Skip to content

Commit 6f136de

Browse files
committed
fix dim bugs and update tests
1 parent 016f266 commit 6f136de

File tree

7 files changed

+14
-7
lines changed

7 files changed

+14
-7
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def aten_ops_amax(
645645
SourceIR.ATEN,
646646
name,
647647
args[0],
648-
args[1],
648+
args_bounds_check(args, 1, replacement=[]),
649649
args_bounds_check(args, 2, replacement=False),
650650
)
651651

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def amax(
2727
):
2828
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
2929

30+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
31+
dim = tuple(range(len(input_val.shape)))
32+
3033
layer = ctx.net.add_reduce(
3134
input_val,
3235
trt.ReduceOperation.MAX,
@@ -51,7 +54,7 @@ def sum(
5154
):
5255
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
5356

54-
if dim is None:
57+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
5558
dim = tuple(range(len(input_val.shape)))
5659

5760
layer = ctx.net.add_reduce(
@@ -169,7 +172,7 @@ def mean(
169172
):
170173
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
171174

172-
if dim is None:
175+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
173176
dim = tuple(range(len(input_val.shape)))
174177

175178
layer = ctx.net.add_reduce(

tests/py/dynamo/conversion/test_amax_aten.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def forward(self, x):
3030

3131
@parameterized.expand(
3232
[
33+
((1, 2, 4), [], True),
3334
((3, 2, 4), [1], True),
3435
((2, 1, 4, 5), [0, 3], True),
3536
((2, 3, 4, 5), [0, 1, 2, 3], False),
@@ -72,6 +73,7 @@ def forward(self, x):
7273

7374
@parameterized.expand(
7475
[
76+
((1, 2, 4), [], True, torch.int, 0, 5),
7577
((3, 2, 4), [1], True, torch.int, 0, 5),
7678
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
7779
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),

tests/py/dynamo/conversion/test_max_aten.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestMaxConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_min_aten.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestMinConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_prod_aten.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestProdConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)

tests/py/dynamo/conversion/test_sum_aten.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
class TestSumConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12+
((1, 2),),
1213
((3, 2, 4),),
1314
((2, 3, 4, 5),),
14-
((2, 3, 4, 5),),
1515
((6, 7, 5, 4, 5),),
1616
]
1717
)
@@ -51,6 +51,7 @@ def forward(self, x):
5151

5252
@parameterized.expand(
5353
[
54+
((1, 2, 4), [], True),
5455
((3, 2, 4), [1], True),
5556
((2, 1, 4, 5), None, True),
5657
((2, 3, 4, 5), [0, 1, 2, 3], False),
@@ -93,6 +94,7 @@ def forward(self, x):
9394

9495
@parameterized.expand(
9596
[
97+
((1, 2, 4), [], True, torch.int, 0, 5),
9698
((3, 2, 4), [1], True, torch.int, 0, 5),
9799
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
98100
((2, 3, 4, 5), None, False, torch.int32, -5, 0),

0 commit comments

Comments
 (0)