Skip to content

Commit 7f96986

Browse files
authored
[Relay][Pass] Simplify consecutive transpose/layout_transform (#7656)
* [Relay][Pass] Simplify consecutive transpose/layout_transform * lint * fix * support negative * comment
1 parent 068fed9 commit 7f96986

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

src/relay/op/make_op.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);
7575

7676
Expr MakeStack(Expr data, int axis);
7777

78+
Expr MakeTranspose(Expr data, Array<Integer> axes);
79+
7880
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
7981
String slice_mode);
8082

src/relay/transforms/simplify_expr.cc

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,99 @@ class SimplifyReshape : public SimplifyPattern {
8282
DFPattern x_;
8383
};
8484

85+
/*!
86+
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
87+
* and merges or cancels them.
88+
*/
89+
class SimplifyTranspose : public SimplifyPattern {
90+
public:
91+
SimplifyTranspose() {
92+
x_ = IsWildcard();
93+
auto trans1 = IsOp("transpose") || IsOp("layout_transform");
94+
auto trans2 = IsOp("transpose") || IsOp("layout_transform");
95+
pattern_ = trans1({trans2({x_})});
96+
}
97+
98+
Expr callback(const Expr& pre, const Expr& post,
99+
const Map<DFPattern, Array<Expr>>& node_map) const override {
100+
// Helper function to get the axes from call node attribute
101+
auto get_axes_from_call = [](const Call trans_call, int ndim) {
102+
std::vector<int> attr_axes;
103+
if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
104+
if (attr->axes.defined()) {
105+
for (int i = 0; i < ndim; ++i) {
106+
int64_t axis = attr->axes[i];
107+
axis += (axis < 0) ? ndim : 0;
108+
attr_axes.push_back(axis);
109+
}
110+
} else {
111+
// Empty axes means reverse
112+
for (int i = ndim - 1; i >= 0; --i) {
113+
attr_axes.push_back(i);
114+
}
115+
}
116+
} else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
117+
Layout src_layout(attr->src_layout);
118+
Layout dst_layout(attr->dst_layout);
119+
for (int i = 0; i < ndim; ++i) {
120+
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
121+
}
122+
} else {
123+
CHECK(false) << "Expected transpose or layout_transform, but got "
124+
<< Downcast<Op>(trans_call->op)->name;
125+
}
126+
return std::move(attr_axes);
127+
};
128+
129+
auto x = node_map[x_][0];
130+
131+
// Initialize axes
132+
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
133+
Array<Integer> axes;
134+
for (int i = 0; i < ndim; ++i) {
135+
axes.push_back(i);
136+
}
137+
138+
// Collect axes changes from the matched pattern, including two consecutive transposes.
139+
std::vector<std::vector<int>> interm_axes;
140+
Call trans_call = Downcast<Call>(post);
141+
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
142+
trans_call = Downcast<Call>(trans_call->args[0]);
143+
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
144+
145+
// Calculate the final axes in reverse order (from root to output)
146+
auto it = interm_axes.rbegin();
147+
while (it != interm_axes.rend()) {
148+
auto interm = *it;
149+
150+
Array<Integer> new_axes;
151+
for (int i = 0; i < ndim; ++i) {
152+
new_axes.push_back(axes[interm[i]]);
153+
}
154+
axes = new_axes;
155+
it++;
156+
}
157+
158+
// Check if the transpose is still required
159+
bool need_transpose = false;
160+
for (int i = 0; i < ndim; ++i) {
161+
if (axes[i] != i) {
162+
need_transpose = true;
163+
break;
164+
}
165+
}
166+
167+
if (need_transpose) {
168+
return MakeTranspose(x, axes);
169+
}
170+
return x;
171+
}
172+
173+
private:
174+
/*! \brief Pattern input */
175+
DFPattern x_;
176+
};
177+
85178
/*!
86179
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
87180
*/
@@ -162,6 +255,7 @@ class ExprSimplifier {
162255
public:
163256
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
164257
CreateCallback(SimplifyReshape());
258+
CreateCallback(SimplifyTranspose());
165259
CreateCallback(FullElementwise());
166260
}
167261
template <typename T>

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,63 @@ def symbolic():
6060
assert tvm.ir.structural_equal(zz, after)
6161

6262

63+
def test_simplify_transpose():
64+
# Test a series of transpose and layout_transform ops
65+
def before1():
66+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
67+
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
68+
y = relay.layout_transform(y, "NHWC", "HWCN") # To HWCN
69+
y = relay.transpose(y, axes=[3, 0, 1, 2]) # To NHWC
70+
return relay.Function([x], y)
71+
72+
def expected1():
73+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
74+
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
75+
return relay.Function([x], y)
76+
77+
# Test that all transpose ops can be cancelled
78+
def before2():
79+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
80+
y = relay.nn.relu(x)
81+
y = relay.transpose(y, axes=[0, 2, 3, 1]) # To NHWC
82+
y = relay.transpose(y, axes=[1, 2, 3, 0]) # To HWCN
83+
y = relay.transpose(y, axes=[3, 2, 0, 1]) # To NCHW
84+
return relay.Function([x], y)
85+
86+
def expected2():
87+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
88+
y = relay.nn.relu(x)
89+
return relay.Function([x], y)
90+
91+
# Test default axis (reverse) and negative axis
92+
def before3():
93+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
94+
y = relay.nn.relu(x)
95+
y = relay.transpose(y) # Reverse
96+
y = relay.transpose(y) # Reverse
97+
y = relay.transpose(y, axes=[0, 2, -1, 1])
98+
y = relay.transpose(y) # Reverse
99+
y = relay.transpose(y) # Reverse
100+
return relay.Function([x], y)
101+
102+
def expected3():
103+
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
104+
y = relay.nn.relu(x)
105+
y = relay.transpose(y, axes=[0, 2, 3, 1])
106+
return relay.Function([x], y)
107+
108+
for before, expected in [
109+
[before1(), expected1()],
110+
[before2(), expected2()],
111+
[before3(), expected3()],
112+
]:
113+
after = run_opt_pass(before, transform.SimplifyExpr())
114+
expected = run_opt_pass(expected, transform.InferType())
115+
assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format(
116+
after, expected
117+
)
118+
119+
63120
def test_simplify_full_elementwise():
64121
def validate(shape, value, dtype):
65122
def before_left(x, elem_op, full):
@@ -126,4 +183,5 @@ def after_right(x, elem_op, value):
126183

127184
if __name__ == "__main__":
128185
test_simplify_reshape()
186+
test_simplify_transpose()
129187
test_simplify_full_elementwise()

0 commit comments

Comments
 (0)