2
2
3
3
import torch
4
4
import torch .nn as nn
5
- from harness import DispatchTestCase
6
5
from torch .testing ._internal .common_utils import run_tests
7
6
from torch_tensorrt import Input
8
7
8
+ from .harness import DispatchTestCase
9
+
9
10
10
11
class TestIndexConverter (DispatchTestCase ):
11
- def test_index_zero (self ):
12
+ def test_index_zero_two_dim (self ):
12
13
class TestModule (nn .Module ):
14
+ def __init__ (self ):
15
+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
16
+ super ().__init__ ()
17
+
13
18
def forward (self , x ):
14
19
index0 = torch .randint (0 , 1 , (1 , 1 ))
15
- indices = [None , index0 ]
20
+ indices = [None , self . index0 ]
16
21
out = torch .ops .aten .index .Tensor (x , indices )
17
22
return out
18
23
@@ -23,11 +28,14 @@ def forward(self, x):
23
28
expected_ops = {torch .ops .aten .index .Tensor },
24
29
)
25
30
26
- def test_index_zero_index_one (self ):
31
+ def test_index_zero_index_three_dim (self ):
27
32
class TestModule (nn .Module ):
33
+ def __init__ (self ):
34
+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
35
+ super ().__init__ ()
36
+
28
37
def forward (self , x ):
29
- index0 = torch .randint (0 , 1 , (1 , 1 ))
30
- indices = [None , index0 , None ]
38
+ indices = [None , self .index0 , None ]
31
39
out = torch .ops .aten .index .Tensor (x , indices )
32
40
return out
33
41
@@ -38,76 +46,101 @@ def forward(self, x):
38
46
expected_ops = {torch .ops .aten .index .Tensor },
39
47
)
40
48
41
- def test_index_zero_index_one_index_two (self ):
49
+ def test_index_zero_index_one_index_two_three_dim (self ):
42
50
class TestModule (nn .Module ):
51
+ def __init__ (self ):
52
+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
53
+ self .index1 = torch .randint (0 , 1 , (1 , 1 ))
54
+ super ().__init__ ()
55
+
43
56
def forward (self , x ):
44
- index0 = torch .randint (0 , 1 , (1 , 1 ))
45
- index1 = torch .randint (0 , 1 , (1 , 1 ))
46
- indices = [None , index0 , index1 ]
57
+ indices = [None , self .index0 , self .index1 ]
47
58
out = torch .ops .aten .index .Tensor (x , indices )
48
59
return out
49
60
50
61
input = [torch .randn (2 , 2 , 2 )]
51
62
self .run_test (
52
63
TestModule (),
53
64
input ,
54
- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
65
+ expected_ops = {torch .ops .aten .index .Tensor },
55
66
)
56
67
57
- def test_index_zero_index_one_SD (self ):
68
+ def test_index_zero_index_one_four_dim (self ):
58
69
class TestModule (nn .Module ):
70
+ def __init__ (self ):
71
+ self .index0 = torch .tensor ([0 , 0 , 1 , 1 ])
72
+ self .index1 = torch .tensor ([0 , 0 , 1 , 1 ])
73
+ super ().__init__ ()
74
+
59
75
def forward (self , x ):
60
- index0 = torch .tensor ([0 , 0 , 1 , 1 ])
61
- index1 = torch .tensor ([0 , 0 , 1 , 1 ])
62
- indices = [None , index0 , index1 , None ]
76
+ indices = [None , self .index0 , self .index1 , None ]
63
77
out = torch .ops .aten .index .Tensor (x , indices )
64
78
return out
65
79
66
80
input = [torch .randn (2 , 4 , 4 , 2 )]
67
81
self .run_test (
68
82
TestModule (),
69
83
input ,
70
- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
84
+ expected_ops = {torch .ops .aten .index .Tensor },
71
85
)
72
86
73
- def test_index_zero_index_one_SD (self ):
87
+ def test_index_zero_index_one_four_dim_SD (self ):
74
88
class TestModule (nn .Module ):
89
+ def __init__ (self ):
90
+ self .index0 = torch .tensor (
91
+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
92
+ )
93
+ self .index1 = torch .tensor (
94
+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
95
+ )
96
+ super ().__init__ ()
97
+
75
98
def forward (self , x ):
76
- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
77
- index1 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
78
- indices = [None , index0 , index1 , None ]
99
+ indices = [None , self .index0 , self .index1 , None ]
79
100
out = torch .ops .aten .index .Tensor (x , indices )
80
101
return out
81
102
82
103
input = [torch .randn (2 , 1280 , 8 , 8 )]
83
104
self .run_test (
84
105
TestModule (),
85
106
input ,
86
- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
107
+ expected_ops = {torch .ops .aten .index .Tensor },
87
108
)
88
109
89
- def test_index_zero_index_one_SD (self ):
110
+ def test_index_one_SD_unsqueeze_four_dim (self ):
90
111
class TestModule (nn .Module ):
112
+ def __init__ (self ):
113
+ self .index0 = torch .tensor (
114
+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
115
+ )
116
+ self .index1 = self .index0 .unsqueeze (0 ).T .long ()
117
+ super ().__init__ ()
118
+
91
119
def forward (self , x ):
92
- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
93
- index1 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
94
- indices = [None , index0 , index1 , None ]
120
+ indices = [None , None , self .index1 , self .index1 ]
95
121
out = torch .ops .aten .index .Tensor (x , indices )
96
122
return out
97
123
98
124
input = [torch .randn (2 , 1280 , 8 , 8 )]
99
125
self .run_test (
100
126
TestModule (),
101
127
input ,
102
- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
128
+ expected_ops = {torch .ops .aten .index .Tensor },
103
129
)
104
130
105
- def test_index_zero_index_one_SD_unsqueeze (self ):
131
+ def test_index_zero_index_one_index_two_SD_unsqueeze_four_dim_broadcast (self ):
106
132
class TestModule (nn .Module ):
133
+ def __init__ (self ):
134
+ self .index0 = torch .tensor (
135
+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
136
+ )
137
+ self .index1 = self .index0 .unsqueeze (0 ).T .long ()
138
+ super ().__init__ ()
139
+
107
140
def forward (self , x ):
108
141
index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
109
142
index1 = index0 .unsqueeze (0 ).T .long ()
110
- indices = [None , None , index1 , index1 ]
143
+ indices = [None , None , self . index0 , self . index1 ]
111
144
out = torch .ops .aten .index .Tensor (x , indices )
112
145
return out
113
146
@@ -118,16 +151,19 @@ def forward(self, x):
118
151
expected_ops = {torch .ops .aten .index .Tensor },
119
152
)
120
153
121
- def test_index_zero_index_one_index_two_SD_unsqueeze (self ):
154
+ def test_index_zero_index_one_index_four_dim_non_continuous (self ):
122
155
class TestModule (nn .Module ):
156
+ def __init__ (self ):
157
+ self .index0 = torch .tensor ([0 , 0 , 1 , 1 ])
158
+ self .index1 = torch .tensor ([0 , 0 , 1 , 1 ])
159
+ super ().__init__ ()
160
+
123
161
def forward (self , x ):
124
- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
125
- index1 = index0 .unsqueeze (0 ).T .long ()
126
- indices = [None , None , index0 , index1 ]
162
+ indices = [None , self .index0 , None , self .index1 ]
127
163
out = torch .ops .aten .index .Tensor (x , indices )
128
164
return out
129
165
130
- input = [torch .randn (2 , 1280 , 8 , 8 )]
166
+ input = [torch .randn (2 , 4 , 4 , 2 )]
131
167
self .run_test (
132
168
TestModule (),
133
169
input ,
0 commit comments