
Commit af3964a

dzhulgakov authored and soumith committed
Backport transposes optimization to v0.3.0 (pytorch#3994)
* Optimizer: optimize transposes in a variety of circumstances (pytorch#3509)
  - No-op transposes
  - Consecutive transposes (fuse them)
  - Transposes into Gemm (fuse them into the transA/transB parameters)
* Touch up an out-of-date comment
* Backport the optimizer changes
1 parent 1645546 commit af3964a
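To make the new pass concrete, here is a minimal standalone sketch in plain C++, independent of the JIT Node/Graph API; composeTransposes and isNopTranspose mirror the helpers added in torch/csrc/jit/passes/onnx/peephole.cpp below, while the main function is purely illustrative. It shows why fusing consecutive transposes and then dropping no-op transposes can erase a pair of Transpose nodes entirely when their permutations cancel.

#include <cstdint>
#include <cstdio>
#include <vector>

// Composition rule used by the pass: transposing by the returned permutation
// is equivalent to transposing by t1 and then by t2 (ret[i] = t2[t1[i]]).
std::vector<int64_t> composeTransposes(const std::vector<int64_t>& t1,
                                       const std::vector<int64_t>& t2) {
  std::vector<int64_t> ret;
  for (size_t i = 0; i < t1.size(); i++)
    ret.push_back(t2[t1[i]]);
  return ret;
}

// A Transpose is a no-op when its permutation is the identity.
bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (size_t i = 0; i < perm.size(); i++)
    if (perm[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

int main() {
  // Two consecutive Transpose nodes that each swap the last two axes of a 3-D tensor.
  std::vector<int64_t> perm = {0, 2, 1};
  // fuseConsecutiveTransposes would rewrite them into a single Transpose with
  // the composed permutation {0, 1, 2} ...
  std::vector<int64_t> fused = composeTransposes(perm, perm);
  // ... which eliminateNopTranspose then removes, since it is the identity.
  std::printf("fused perm is identity: %s\n", isNopTranspose(fused) ? "yes" : "no");
  return 0;
}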

File tree

4 files changed, +83 -1 lines


torch/csrc/jit/attributes.h

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs
 
 // CRTP so that Node which inherits Attributes can be return for
 // method chaining e.g:
-// Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5);
+// Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
 // we return Derived* pointers because Nodes are normally held as pointers.
 template<typename Derived>
 struct Attributes {

torch/csrc/jit/interned_strings.h

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,8 @@ _(shape) \
 _(axes) \
 _(group) \
 _(inplace) \
+_(transA) \
+_(transB) \
 _(other)
 
 enum BuiltinSymbol {

torch/csrc/jit/passes/onnx/peephole.cpp

Lines changed: 79 additions & 0 deletions
@@ -15,6 +15,27 @@ std::unordered_set<NodeKind> broadcasting = {
   kGemm,
 };
 
+bool isNopTranspose(const std::vector<int64_t> & perm) {
+  for (size_t i = 0; i < perm.size(); i++)
+    if (perm[i] != i)
+      return false;
+  return true;
+}
+
+// returns a vector `ret` such that transposing by `ret` is equivalent
+// to transposing by `t1` and then by `t2`
+std::vector<int64_t> composeTransposes(const std::vector<int64_t> & t1,
+                                       const std::vector<int64_t> & t2) {
+  JIT_ASSERT(t1.size() == t2.size());
+  std::vector<int64_t> ret;
+  for (size_t i = 0; i < t1.size(); i++) {
+    JIT_ASSERT(   t1[i]  < t2.size());
+    JIT_ASSERT(t2[t1[i]] < t2.size());
+    ret.push_back(t2[t1[i]]);
+  }
+  return ret;
+}
+
 bool isBroadcasting(Node *node) {
   return broadcasting.count(node->kind());
 }
@@ -93,13 +114,68 @@ void fuseBroadcast(std::shared_ptr<Graph>& graph) {
   }
 }
 
+void fuseConsecutiveTransposes(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kTranspose && n->input()->kind() == kTranspose) {
+      auto origInput = n->input();
+      n->is_(kperm, composeTransposes(origInput->is(kperm), n->is(kperm)));
+      n->replaceInput(0, origInput->input());
+      if (origInput->uses().size() == 0) {
+        origInput->destroy();
+      }
+      continue;
+    }
+  }
+}
+
+void eliminateNopTranspose(std::shared_ptr<Graph>& graph) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kTranspose) {
+      if (isNopTranspose(n->is(kperm))) {
+        n->replaceAllUsesWith(n->input());
+        it.destroyCurrent();
+        continue;
+      }
+    }
+  }
+}
+
+void fuseTransposeIntoGemm(std::shared_ptr<Graph>& graph) {
+  static const std::vector<int64_t> simpleTransPerm({1,0});
+
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
+    auto* n = *it;
+
+    if (n->kind() == kGemm) {
+      for (size_t i : {0,1}) {
+        auto inp = n->inputs()[i];
+        auto trans = i == 0 ? ktransA : ktransB;
+        if (inp->kind() == kTranspose && inp->is(kperm) == simpleTransPerm) {
+          n->replaceInput(i, inp->input());
+          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
+          if (inp->uses().size() == 0) {
+            inp->destroy();
+          }
+        }
+      }
+    }
+  }
+}
+
 // This optimization does ONNX-specific peephole optimizations.
 //
 // At the moment, here are the optimizations it does:
 // - This optimization fuses expand calls into ONNX operators, because it is
 //   easier for non-strided backends to more efficiently do broadcasts if this is
 //   local information. This optimization is not useful for PyTorch as 'expand'
 //   is free.
+// - Fusing of consecutive transposes
+// - Elimination of NOP transposes
+// - Fusing of transposes into Gemm
 //
 // Before you write an optimization here, ask yourself, "Could I do this
 // optimization on ATen operators"? If so, you should seriously consider
@@ -111,6 +187,9 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   // TODO: make it easier not to do O(k) iterations over the graph, where
   // k is the number of distinct peephole optimizations
   fuseBroadcast(graph);
+  fuseConsecutiveTransposes(graph);
+  eliminateNopTranspose(graph);
+  fuseTransposeIntoGemm(graph);
 }
 
 }}
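The Gemm rewrite above is the one non-obvious case, so here is a hedged, standalone sketch of the transA/transB update in plain C++. It is not the in-tree code (which operates on JIT Node attributes), and the helper name foldTransposeIntoFlag is purely illustrative. In graph terms, Gemm(Transpose(A, perm={1,0}), B) becomes Gemm(A, B, transA=1); because the flag is toggled rather than unconditionally set, fusing a transpose into a Gemm that already carries transA=1 flips it back to 0, since two transposes cancel.

#include <cstdint>
#include <cstdio>
#include <optional>

// Sketch of the flag update in fuseTransposeIntoGemm: when a 2-D Transpose
// with perm = {1, 0} feeds a Gemm input, the Transpose node is dropped and the
// matching transA/transB flag is set to 1 if absent, or negated if present.
// (foldTransposeIntoFlag is a hypothetical helper for illustration only.)
int64_t foldTransposeIntoFlag(std::optional<int64_t> trans) {
  return trans.has_value() ? !*trans : 1;
}

int main() {
  // Gemm had no transA attribute: fusing Transpose(A, {1, 0}) sets transA = 1.
  std::printf("absent   -> %lld\n", static_cast<long long>(foldTransposeIntoFlag(std::nullopt)));
  // Gemm already had transA = 1: fusing another transpose flips it back to 0.
  std::printf("transA=1 -> %lld\n", static_cast<long long>(foldTransposeIntoFlag(1)));
  return 0;
}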

torch/csrc/jit/passes/peephole.cpp

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
   for (auto it = graph->begin(); it != graph->end(); ++it) {
     auto* n = *it;
 
+    // eliminate redundant expand
     if (n->kind() == kexpand) {
      if (n->is(ksize) == n->input()->type()->expect<TensorType>()->sizes()) {
        n->replaceAllUsesWith(n->input());
