Skip to content

Commit

Permalink
[PASS] PostOrderVisit (apache#2169)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Nov 26, 2018
1 parent b5e0d79 commit 0a1f3d4
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 13 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
}
};

// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;

TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
8 changes: 8 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);

/*
* \brief Bind function parameters or free variables.
*
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
from .expr import Expr
from .ty import Type

def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
only once.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
return _ir_pass.post_order_visit(expr, fvisit)

def infer_type(expr, mod=None):
"""Infer the type of expr under the context of mod.
Expand Down
30 changes: 30 additions & 0 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,36 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {

void ExprVisitor::VisitType(const Type& t) { return; }


// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
public:
explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
void VisitExpr(const Expr& e) final {
if (visited_.count(e.get()) != 0) return;
visited_.insert(e.get());
ExprVisitor::VisitExpr(e);
f_(e);
}

private:
std::function<void(const Expr&)> f_;
std::unordered_set<const Node*> visited_;
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisit(fvisit).VisitExpr(e);
}

TVM_REGISTER_API("relay._ir_pass.post_order_visit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
PostOrderVisit(args[0], [f](const Expr& n) {
f(n);
});
});


// Implement bind.
class ExprBinder : public ExprMutator {
public:
Expand Down
16 changes: 3 additions & 13 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <topi/elemwise.h>
#include "../type_relations.h"
#include "../op_common.h"
Expand Down Expand Up @@ -89,19 +90,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy")
.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));


// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;

TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};
// relay.clip
TVM_REGISTER_NODE_TYPE(ClipAttrs);

TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
Expand Down

0 comments on commit 0a1f3d4

Please sign in to comment.