Skip to content

Commit 5b0e5fc

Browse files
authored
Merge pull request #2500 from pytorch/view_slice_bugfixes_cherry_pick
cherry-pick: View and slice bugfixes
2 parents 73fefbb + b5dc751 commit 5b0e5fc

File tree

7 files changed

+160
-30
lines changed

7 files changed

+160
-30
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ def aten_ops_select(
687687

688688

689689
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
690+
@enforce_tensor_types(
691+
{
692+
0: (TRTTensor,),
693+
}
694+
)
690695
def aten_ops_slice(
691696
ctx: ConversionContext,
692697
target: Target,
@@ -700,9 +705,9 @@ def aten_ops_slice(
700705
SourceIR.ATEN,
701706
name,
702707
args[0],
703-
args[1],
704-
args[2],
705-
args[3],
708+
args_bounds_check(args, 1, replacement=0),
709+
args_bounds_check(args, 2, replacement=None),
710+
args_bounds_check(args, 3, replacement=None),
706711
args_bounds_check(args, 4, replacement=1),
707712
)
708713

@@ -877,6 +882,11 @@ def aten_ops_clone_copy_placeholder(
877882

878883

879884
@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
885+
@enforce_tensor_types(
886+
{
887+
0: (TRTTensor,),
888+
}
889+
)
880890
def aten_ops_expand(
881891
ctx: ConversionContext,
882892
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def get_positive_dim(
339339
) -> Union[int, Tuple[int, ...]]:
340340
"""
341341
Given an integer number or tuple that represents dimension(s) in the array,
342-
transform it to a positive integer dim if it's negative. Otherwise, do
343-
nothing.
342+
transform it to a positive integer dim if it's negative.
343+
Otherwise, clamp it so that it does not exceed the dimension size.
344344
345345
Args:
346346
dim (Union[int, Sequence[int]]): An integer or Sequence of integers that represent dimension(s) in an array.
@@ -353,7 +353,8 @@ def get_positive_dim(
353353
def positive_dim(d: int) -> int:
354354
if d < 0:
355355
return d % dim_size
356-
return d
356+
else:
357+
return min(d, dim_size)
357358

358359
return (
359360
positive_dim(dim)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ def slice_op( # TODO: This should be slice not whatever is in base
2121
name: str,
2222
input: TRTTensor,
2323
dim: int,
24-
start: int,
25-
stop: int,
24+
start: Optional[int],
25+
stop: Optional[int],
2626
step: int,
2727
) -> TRTTensor:
28-
if not isinstance(input, TRTTensor):
29-
raise RuntimeError(
30-
f"slice_tensor received input {input} that is not part "
31-
"of the TensorRT region!"
32-
)
28+
# Special case for start being None
29+
if start is None:
30+
start = 0
31+
32+
# Special case for stop being None
33+
if stop is None:
34+
stop = input.shape[dim]
3335

3436
dim = get_positive_dim(dim, len(input.shape))
3537
start = get_positive_dim(start, input.shape[dim])
@@ -39,9 +41,6 @@ def slice_op( # TODO: This should be slice not whatever is in base
3941
# Check whether slice target dim is dynamic shape dim
4042
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
4143

42-
if stop == 2**63 - 1:
43-
stop = input.shape[dim]
44-
4544
start_slice = [0] * len(input.shape)
4645
start_slice[dim] = start
4746
stride_slice = [1] * len(input.shape)
@@ -62,11 +61,6 @@ def expand(
6261
input_t: TRTTensor,
6362
shape: Shape,
6463
) -> TRTTensor:
65-
if not isinstance(input_t, TRTTensor):
66-
raise RuntimeError(
67-
f"expand received input {input_t} that is not a TensorRT ITensor"
68-
)
69-
7064
shape_rank = len(shape)
7165
initial_tensor_rank = len(input_t.shape)
7266

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1111
from .repair_input_as_output import repair_input_as_output
1212
from .replace_max_pool_with_indices import replace_max_pool_with_indices
13+
from .view_to_reshape import view_to_reshape
1314

1415
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1516
[
@@ -19,6 +20,7 @@
1920
lower_efficient_attention,
2021
fuse_prims_broadcast,
2122
replace_max_pool_with_indices,
23+
view_to_reshape,
2224
]
2325
)
2426

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
from typing import Callable, List, Sequence, Tuple
3+
4+
import torch
5+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
6+
clean_up_graph_after_modifications,
7+
)
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def view_to_reshape(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues

    Args:
        gm: Graph module to rewrite. Every ``aten.view.default`` call matched by
            the pattern is replaced with ``aten.reshape.default``.
        sample_inputs: Not used by this pass.

    Returns:
        The graph module, cleaned up via ``clean_up_graph_after_modifications``
        if at least one replacement was made; otherwise the module as passed in.
    """
    orig, replacement = view_replacement()

    # replace_pattern returns the list of matches it rewrote; an empty list
    # means the graph was left untouched and no cleanup is needed.
    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
        gm = clean_up_graph_after_modifications(gm)
        # Lazy %-formatting: the graph is only rendered to a string when
        # DEBUG logging is actually enabled.
        logger.debug("Graph after replacing view with reshape:\n%s", gm.graph)

    return gm
23+
24+
25+
def view_replacement() -> (
    Tuple[
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
    ]
):
    """Constructs the original and replacement functions for view

    Returns:
        A ``(orig, replacement)`` pair of plain Python callables. ``orig``
        calls ``aten.view.default`` and ``replacement`` calls
        ``aten.reshape.default`` with the same ``(input, shape)`` arguments.
    """

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)

    # Replacement graph
    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.reshape.default(input, shape)

    return orig, replacement

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from .harness import DispatchTestCase
88

99

10-
class TestSelectConverter(DispatchTestCase):
10+
class TestSliceConverter(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13-
("select_dim_start_stop_step", 0, 0, 7, 2),
14-
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
15-
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
16-
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
17-
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
13+
("slice_dim_start_stop_step", 0, 0, 7, 2),
14+
("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
15+
("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
16+
("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
17+
("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
18+
("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
19+
("slice_dim_start_stop_step_none", 2, None, None, 1),
1820
]
1921
)
2022
def test_slice(self, _, dim, start, stop, step):
@@ -32,12 +34,27 @@ def forward(self, input):
3234
input,
3335
)
3436

37+
def test_slice_empty(self):
    """Run aten.slice.Tensor with only the input tensor supplied.

    Dim, start, stop and step are all omitted from the call.
    """

    class TestModule(torch.nn.Module):
        def forward(self, input):
            return torch.ops.aten.slice.Tensor(input)

    sample_inputs = [torch.randn(10, 10, 3, 1)]
    self.run_test(TestModule(), sample_inputs)
51+
3552

36-
class TestSelectConverterDynamicShape(DispatchTestCase):
53+
class TestSliceConverterDynamicShape(DispatchTestCase):
3754
@parameterized.expand(
3855
[
39-
("select_dim_start_stop_step", 1, 0, 7, 2),
40-
("select_dim_start_stop_step", 1, 0, 10, 2),
56+
("slice_dim_start_stop_step", 1, 0, 7, 2),
57+
("slice_dim_start_stop_step", 1, 0, 10, 2),
4158
]
4259
)
4360
def test_slice(self, _, dim, start, stop, step):

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,5 +267,70 @@ def forward(self, q, k, v):
267267
torch._dynamo.reset()
268268

269269

270+
class TestLowerViewToReshape(TestCase):
    """Checks that lowering rewrites ``aten.view`` into ``aten.reshape``."""

    def test_view_to_reshape(self):
        """Lower a single-view module, check the ops seen, then compare outputs.

        The first half inspects the lowered graph through ``lower_graph_testing``;
        the second half compiles the traced module with ``torch_tensorrt`` and
        compares its output against the traced module run directly.
        """

        class ViewToReshape(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.view.default(input, (1, 1, -1))
                return out

        inputs = [
            torch.rand((3, 4, 5, 32)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
        expected_ops = {torch.ops.aten.reshape.default}
        unexpected_ops = {
            torch.ops.aten.view.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        # assertEqual is used instead of the deprecated assertEquals alias,
        # which was removed from unittest in Python 3.12.
        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="ViewToReshape TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()
333+
334+
270335
if __name__ == "__main__":
271336
run_tests()

0 commit comments

Comments
 (0)