Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh

notifications:
# Emails are sent to the committer's git-configured email address by default,
email:
on_success: change
on_failure: always
261 changes: 261 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*!
* Copyright (c) 2017 by Contributors
* \file ir_functor_ext.h
* \brief More powerful Visitor that allows define function signatures.
*/
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_

#include <tvm/ir_functor.h>
#include "./ir.h"

namespace tvm {
namespace ir {

/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \code
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
* : public ir::ExprFunctor<int(const Expr&, int)> {
* public:
* int VisitExpr_(const Variable* op, int b) final {
* return b;
* }
* int VisitExpr_(const IntImm* op, int b) final {
* return op->value;
* }
* int VisitExpr_(const Add* op, int b) final {
* return Visit(op->a, b) + Visit(op->b, b);
* }
* };
* MyExprFunctor f;
* Var x("x");
* CHECK_EQ(f(x + 1, 2), 3);
* \endcode
*
* \note Why do we need this more powerful Functor:
*
* We often need to implement a transformer tasks.
* Say we want to take Expr and transform it to some analysis result,
* This easily be done incorrectly using plain Visitor. See IRVisitor's
* document for possible error cases.
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture R(const Expr&, Args...)
*/
template<typename FType>
class ExprFunctor;
/*!
* \brief Same as ExprFunctor except it is applied on statements
* \tparam FType The function signature.
*/
template<typename FType>
class StmtFunctor;

// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT { \
return VisitExprDefault_(op, std::forward<Args>(args)...); \
}
#define STMT_FUNCTOR_DEFAULT { \
return VisitStmtDefault_(op, std::forward<Args>(args)...); \
}

#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \

#define IR_STMT_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \

template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~ExprFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(Variable);
IR_EXPR_FUNCTOR_DISPATCH(Load);
IR_EXPR_FUNCTOR_DISPATCH(Let);
IR_EXPR_FUNCTOR_DISPATCH(Call);
IR_EXPR_FUNCTOR_DISPATCH(Add);
IR_EXPR_FUNCTOR_DISPATCH(Sub);
IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ);
IR_EXPR_FUNCTOR_DISPATCH(NE);
IR_EXPR_FUNCTOR_DISPATCH(LT);
IR_EXPR_FUNCTOR_DISPATCH(LE);
IR_EXPR_FUNCTOR_DISPATCH(GT);
IR_EXPR_FUNCTOR_DISPATCH(GE);
IR_EXPR_FUNCTOR_DISPATCH(And);
IR_EXPR_FUNCTOR_DISPATCH(Or);
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
IR_EXPR_FUNCTOR_DISPATCH(Cast);
IR_EXPR_FUNCTOR_DISPATCH(Not);
IR_EXPR_FUNCTOR_DISPATCH(Select);
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
IR_EXPR_FUNCTOR_DISPATCH(StringImm);
return vtable;
}
};

template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
private:
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;

public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~StmtFunctor() {}
/*!
* \brief Same as call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitStmt(const Stmt& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
IR_STMT_FUNCTOR_DISPATCH(LetStmt);
IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
IR_STMT_FUNCTOR_DISPATCH(For);
IR_STMT_FUNCTOR_DISPATCH(Allocate);
IR_STMT_FUNCTOR_DISPATCH(Store);
IR_STMT_FUNCTOR_DISPATCH(Free);
IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
};

#undef IR_STMT_FUNCTOR_DISPATCH
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT

} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
62 changes: 0 additions & 62 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);

virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
Expand All @@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
};

/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};

} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
Loading