Skip to content

Commit 709d626

Browse files
committed
Changing test cases so they do not produce a get_item trace, and adding a test case for non-continuous indices
1 parent 6bcc48a commit 709d626

File tree

2 files changed

+75
-36
lines changed

2 files changed

+75
-36
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from torch.fx.node import Target
77
from torch_tensorrt.dynamo._SourceIR import SourceIR
88
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
9-
from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
9+
from torch_tensorrt.dynamo.conversion.converter_utils import (
10+
broadcastable,
11+
get_trt_tensor,
12+
)
1013
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
1114
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
1215
from torch_tensorrt.fx.converters.converter_utils import (
1316
get_positive_dim,
14-
get_trt_tensor,
1517
has_dynamic_shape,
1618
set_layer_name,
1719
to_numpy,
@@ -88,6 +90,7 @@ def index(
8890

8991
# here we need to check if all the index are broadcastable
9092
# if no, then we need to broadcast
93+
input = get_trt_tensor(network, input, name + f"_input_to_tensor")
9194

9295
last_index = None
9396
for i, ind in enumerate(index):
@@ -334,7 +337,7 @@ def index(
334337
concat_final_tensor = concat_final_shape_layer.get_output(0)
335338

336339
reshape_layer = network.add_shuffle(gather_out)
337-
reshape_layer.setInput(1, concat_final_tensor)
340+
reshape_layer.set_input(1, concat_final_tensor)
338341
set_layer_name(
339342
reshape_layer,
340343
target,

tests/py/dynamo/conversion/test_index_aten.py

+69-33
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@
22

33
import torch
44
import torch.nn as nn
5-
from harness import DispatchTestCase
65
from torch.testing._internal.common_utils import run_tests
76
from torch_tensorrt import Input
87

8+
from .harness import DispatchTestCase
9+
910

1011
class TestIndexConverter(DispatchTestCase):
11-
def test_index_zero(self):
12+
def test_index_zero_two_dim(self):
1213
class TestModule(nn.Module):
14+
def __init__(self):
15+
self.index0 = torch.randint(0, 1, (1, 1))
16+
super().__init__()
17+
1318
def forward(self, x):
1419
index0 = torch.randint(0, 1, (1, 1))
15-
indices = [None, index0]
20+
indices = [None, self.index0]
1621
out = torch.ops.aten.index.Tensor(x, indices)
1722
return out
1823

@@ -23,11 +28,14 @@ def forward(self, x):
2328
expected_ops={torch.ops.aten.index.Tensor},
2429
)
2530

26-
def test_index_zero_index_one(self):
31+
def test_index_zero_index_three_dim(self):
2732
class TestModule(nn.Module):
33+
def __init__(self):
34+
self.index0 = torch.randint(0, 1, (1, 1))
35+
super().__init__()
36+
2837
def forward(self, x):
29-
index0 = torch.randint(0, 1, (1, 1))
30-
indices = [None, index0, None]
38+
indices = [None, self.index0, None]
3139
out = torch.ops.aten.index.Tensor(x, indices)
3240
return out
3341

@@ -38,76 +46,101 @@ def forward(self, x):
3846
expected_ops={torch.ops.aten.index.Tensor},
3947
)
4048

41-
def test_index_zero_index_one_index_two(self):
49+
def test_index_zero_index_one_index_two_three_dim(self):
4250
class TestModule(nn.Module):
51+
def __init__(self):
52+
self.index0 = torch.randint(0, 1, (1, 1))
53+
self.index1 = torch.randint(0, 1, (1, 1))
54+
super().__init__()
55+
4356
def forward(self, x):
44-
index0 = torch.randint(0, 1, (1, 1))
45-
index1 = torch.randint(0, 1, (1, 1))
46-
indices = [None, index0, index1]
57+
indices = [None, self.index0, self.index1]
4758
out = torch.ops.aten.index.Tensor(x, indices)
4859
return out
4960

5061
input = [torch.randn(2, 2, 2)]
5162
self.run_test(
5263
TestModule(),
5364
input,
54-
expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
65+
expected_ops={torch.ops.aten.index.Tensor},
5566
)
5667

57-
def test_index_zero_index_one_SD(self):
68+
def test_index_zero_index_one_four_dim(self):
5869
class TestModule(nn.Module):
70+
def __init__(self):
71+
self.index0 = torch.tensor([0, 0, 1, 1])
72+
self.index1 = torch.tensor([0, 0, 1, 1])
73+
super().__init__()
74+
5975
def forward(self, x):
60-
index0 = torch.tensor([0, 0, 1, 1])
61-
index1 = torch.tensor([0, 0, 1, 1])
62-
indices = [None, index0, index1, None]
76+
indices = [None, self.index0, self.index1, None]
6377
out = torch.ops.aten.index.Tensor(x, indices)
6478
return out
6579

6680
input = [torch.randn(2, 4, 4, 2)]
6781
self.run_test(
6882
TestModule(),
6983
input,
70-
expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
84+
expected_ops={torch.ops.aten.index.Tensor},
7185
)
7286

73-
def test_index_zero_index_one_SD(self):
87+
def test_index_zero_index_one_four_dim_SD(self):
7488
class TestModule(nn.Module):
89+
def __init__(self):
90+
self.index0 = torch.tensor(
91+
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
92+
)
93+
self.index1 = torch.tensor(
94+
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
95+
)
96+
super().__init__()
97+
7598
def forward(self, x):
76-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
77-
index1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
78-
indices = [None, index0, index1, None]
99+
indices = [None, self.index0, self.index1, None]
79100
out = torch.ops.aten.index.Tensor(x, indices)
80101
return out
81102

82103
input = [torch.randn(2, 1280, 8, 8)]
83104
self.run_test(
84105
TestModule(),
85106
input,
86-
expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
107+
expected_ops={torch.ops.aten.index.Tensor},
87108
)
88109

89-
def test_index_zero_index_one_SD(self):
110+
def test_index_one_SD_unsqueeze_four_dim(self):
90111
class TestModule(nn.Module):
112+
def __init__(self):
113+
self.index0 = torch.tensor(
114+
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
115+
)
116+
self.index1 = self.index0.unsqueeze(0).T.long()
117+
super().__init__()
118+
91119
def forward(self, x):
92-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
93-
index1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
94-
indices = [None, index0, index1, None]
120+
indices = [None, None, self.index1, self.index1]
95121
out = torch.ops.aten.index.Tensor(x, indices)
96122
return out
97123

98124
input = [torch.randn(2, 1280, 8, 8)]
99125
self.run_test(
100126
TestModule(),
101127
input,
102-
expected_ops={torch.ops.aten.index.Tensor, operator.getitem},
128+
expected_ops={torch.ops.aten.index.Tensor},
103129
)
104130

105-
def test_index_zero_index_one_SD_unsqueeze(self):
131+
def test_index_zero_index_one_index_two_SD_unsqueeze_four_dim_broadcast(self):
106132
class TestModule(nn.Module):
133+
def __init__(self):
134+
self.index0 = torch.tensor(
135+
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
136+
)
137+
self.index1 = self.index0.unsqueeze(0).T.long()
138+
super().__init__()
139+
107140
def forward(self, x):
108141
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
109142
index1 = index0.unsqueeze(0).T.long()
110-
indices = [None, None, index1, index1]
143+
indices = [None, None, self.index0, self.index1]
111144
out = torch.ops.aten.index.Tensor(x, indices)
112145
return out
113146

@@ -118,16 +151,19 @@ def forward(self, x):
118151
expected_ops={torch.ops.aten.index.Tensor},
119152
)
120153

121-
def test_index_zero_index_one_index_two_SD_unsqueeze(self):
154+
def test_index_zero_index_one_index_four_dim_non_continuous(self):
122155
class TestModule(nn.Module):
156+
def __init__(self):
157+
self.index0 = torch.tensor([0, 0, 1, 1])
158+
self.index1 = torch.tensor([0, 0, 1, 1])
159+
super().__init__()
160+
123161
def forward(self, x):
124-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
125-
index1 = index0.unsqueeze(0).T.long()
126-
indices = [None, None, index0, index1]
162+
indices = [None, self.index0, None, self.index1]
127163
out = torch.ops.aten.index.Tensor(x, indices)
128164
return out
129165

130-
input = [torch.randn(2, 1280, 8, 8)]
166+
input = [torch.randn(2, 4, 4, 2)]
131167
self.run_test(
132168
TestModule(),
133169
input,

0 commit comments

Comments
 (0)