Skip to content

Commit

Permalink
[Relay][Pass] Simplify consecutive transpose/layout_transform (apache…
Browse files Browse the repository at this point in the history
…#7656)

* [Relay][Pass] Simplify consecutive transpose/layout_transform

* lint

* fix

* support negative

* comment
  • Loading branch information
comaniac authored and trevor-m committed May 11, 2021
1 parent 71d47ec commit 6e23cbb
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeStack(Expr data, int axis);

Expr MakeTranspose(Expr data, Array<Integer> axes);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
String slice_mode);

Expand Down
94 changes: 94 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,99 @@ class SimplifyReshape : public SimplifyPattern {
DFPattern x_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
*/
class SimplifyTranspose : public SimplifyPattern {
public:
SimplifyTranspose() {
x_ = IsWildcard();
auto trans1 = IsOp("transpose") || IsOp("layout_transform");
auto trans2 = IsOp("transpose") || IsOp("layout_transform");
pattern_ = trans1({trans2({x_})});
}

Expr callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
// Helper function to get the axes from call node attribute
auto get_axes_from_call = [](const Call trans_call, int ndim) {
std::vector<int> attr_axes;
if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i];
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
}
} else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(trans_call->op)->name;
}
return std::move(attr_axes);
};

auto x = node_map[x_][0];

// Initialize axes
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
Array<Integer> axes;
for (int i = 0; i < ndim; ++i) {
axes.push_back(i);
}

// Collect axes changes from the matched pattern, including two consecutive transposes.
std::vector<std::vector<int>> interm_axes;
Call trans_call = Downcast<Call>(post);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
trans_call = Downcast<Call>(trans_call->args[0]);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));

// Calculate the final axes in reverse order (from root to output)
auto it = interm_axes.rbegin();
while (it != interm_axes.rend()) {
auto interm = *it;

Array<Integer> new_axes;
for (int i = 0; i < ndim; ++i) {
new_axes.push_back(axes[interm[i]]);
}
axes = new_axes;
it++;
}

// Check if the transpose is still required
bool need_transpose = false;
for (int i = 0; i < ndim; ++i) {
if (axes[i] != i) {
need_transpose = true;
break;
}
}

if (need_transpose) {
return MakeTranspose(x, axes);
}
return x;
}

private:
/*! \brief Pattern input */
DFPattern x_;
};

/*!
* \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
*/
Expand Down Expand Up @@ -162,6 +255,7 @@ class ExprSimplifier {
public:
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
CreateCallback(SimplifyReshape());
CreateCallback(SimplifyTranspose());
CreateCallback(FullElementwise());
}
template <typename T>
Expand Down
58 changes: 58 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,63 @@ def symbolic():
assert tvm.ir.structural_equal(zz, after)


def test_simplify_transpose():
# Test a series of transpose and layout_transform ops
def before1():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
y = relay.layout_transform(y, "NHWC", "HWCN") # To HWCN
y = relay.transpose(y, axes=[3, 0, 1, 2]) # To NHWC
return relay.Function([x], y)

def expected1():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.transpose(x, axes=[0, 2, 3, 1]) # To NHWC
return relay.Function([x], y)

# Test that all transpose ops can be cancelled
def before2():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
y = relay.transpose(y, axes=[0, 2, 3, 1]) # To NHWC
y = relay.transpose(y, axes=[1, 2, 3, 0]) # To HWCN
y = relay.transpose(y, axes=[3, 2, 0, 1]) # To NCHW
return relay.Function([x], y)

def expected2():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
return relay.Function([x], y)

# Test default axis (reverse) and negative axis
def before3():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
y = relay.transpose(y) # Reverse
y = relay.transpose(y) # Reverse
y = relay.transpose(y, axes=[0, 2, -1, 1])
y = relay.transpose(y) # Reverse
y = relay.transpose(y) # Reverse
return relay.Function([x], y)

def expected3():
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32") # NCHW
y = relay.nn.relu(x)
y = relay.transpose(y, axes=[0, 2, 3, 1])
return relay.Function([x], y)

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format(
after, expected
)


def test_simplify_full_elementwise():
def validate(shape, value, dtype):
def before_left(x, elem_op, full):
Expand Down Expand Up @@ -126,4 +183,5 @@ def after_right(x, elem_op, value):

if __name__ == "__main__":
test_simplify_reshape()
test_simplify_transpose()
test_simplify_full_elementwise()

0 comments on commit 6e23cbb

Please sign in to comment.