@@ -15,6 +15,27 @@ std::unordered_set<NodeKind> broadcasting = {
   kGemm,
 };
 
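+// A transpose is a no-op when its permutation is the identity, i.e.
+// perm[i] == i for every axis.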
+bool isNopTranspose(const std::vector<int64_t>& perm) {
+  for (size_t i = 0; i < perm.size(); i++)
+    if (perm[i] != i)
+      return false;
+  return true;
+}
+
+// returns a vector `ret` such that transposing by `ret` is equivalent
+// to transposing by `t1` and then by `t2`
+std::vector<int64_t> composeTransposes(const std::vector<int64_t>& t1,
+                                       const std::vector<int64_t>& t2) {
+  JIT_ASSERT(t1.size() == t2.size());
+  std::vector<int64_t> ret;
+  for (size_t i = 0; i < t2.size(); i++) {
+    JIT_ASSERT(0 <= t2[i] && (size_t)t2[i] < t1.size());
+    ret.push_back(t1[t2[i]]);
+  }
+  return ret;
+}
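+// For example, composeTransposes({0,2,1}, {2,1,0}) == {1,2,0}: a shape
+// (a,b,c) becomes (a,c,b) after the first transpose and (b,c,a) after the
+// second, which is exactly what a single transpose by {1,2,0} produces.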
+
 bool isBroadcasting(Node *node) {
   return broadcasting.count(node->kind());
 }
@@ -93,13 +114,68 @@ void fuseBroadcast(std::shared_ptr<Graph>& graph) {
   }
 }
 
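+// Rewrites Transpose(Transpose(x, p1), p2) into a single
+// Transpose(x, composeTransposes(p1, p2)), dropping the inner Transpose
+// once nothing else uses it.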
+void fuseConsecutiveTransposes(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kTranspose && n->input()->kind() == kTranspose) {
+      auto origInput = n->input();
+      n->is_(kperm, composeTransposes(origInput->is(kperm), n->is(kperm)));
+      n->replaceInput(0, origInput->input());
+      if (origInput->uses().size() == 0) {
+        origInput->destroy();
+      }
+      continue;
+    }
+  }
+}
+
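+// Removes Transpose nodes whose permutation is the identity, rerouting all
+// of their uses directly to the transpose's input.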
+void eliminateNopTranspose(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kTranspose) {
+      if (isNopTranspose(n->is(kperm))) {
+        n->replaceAllUsesWith(n->input());
+        it.destroyCurrent();
+        continue;
+      }
+    }
+  }
+}
+
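+// A 2-D Transpose with perm == {1,0} feeding Gemm input 0 (A) or 1 (B) can
+// be absorbed by toggling Gemm's transA/transB attribute instead of
+// materializing the transposed tensor.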
+void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
+  static const std::vector<int64_t> simpleTransPerm({1, 0});
+
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kGemm) {
+      for (size_t i : {0, 1}) {
+        auto inp = n->inputs()[i];
+        auto trans = i == 0 ? ktransA : ktransB;
+        if (inp->kind() == kTranspose && inp->is(kperm) == simpleTransPerm) {
+          n->replaceInput(i, inp->input());
+          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
+          if (inp->uses().size() == 0) {
+            inp->destroy();
+          }
+        }
+      }
+    }
+  }
+}
+
 // This optimization does ONNX-specific peephole optimizations.
 //
 // At the moment, here are the optimizations it does:
 // - This optimization fuses expand calls into ONNX operators, because it is
 //   easier for non-strided backends to do broadcasts efficiently if this is
 //   local information. This optimization is not useful for PyTorch as 'expand'
 //   is free.
+// - Fusing of consecutive transposes
+// - Elimination of NOP transposes
+// - Fusing of transposes into Gemm
 //
 // Before you write an optimization here, ask yourself, "Could I do this
 // optimization on ATen operators"? If so, you should seriously consider
@@ -111,6 +187,9 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   // TODO: make it easier not to do O(k) iterations over the graph, where
   // k is the number of distinct peephole optimizations
   fuseBroadcast(graph);
+  fuseConsecutiveTransposes(graph);
+  eliminateNopTranspose(graph);
+  fuseTransposeIntoGemm(graph);
 }
 
 }}