forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Relax IRVisitor/IRMuator (apache#10)
* ExprVisitor/ExprMutator for relax nodes. * Update Visitor & Mutator. * Update Mutator. * DataflowMutator interface. * EwiseFMARewriter. * Update fma rewrite and add test. * Update test. * Fix dataflow block dispatching. * Construct new dataflow block with IRBuilder. * VisitBinding return void and mutate internal IRBuilder. * Simplify. * Update emit dataflow output. * Explicit memeory allocation rewrite. * LazyIRBuilder. * Update ExplicitMemMutator. * Overload IRBuilder::Emit to have 3 styles. * Update IRBuilder/IRMutator interfaces and passes. * Add MatchShape binding to IRBuilder. * Improve IRMutator interface; add Normalize and CanProveShapeEqual to IRBuilder * Update EmitMatchShape. Co-authored-by: ZihengJiang <ziheng@apache.org>
- Loading branch information
Showing
20 changed files
with
1,759 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relax/expr_functor.h | ||
* \brief A more powerful visitor which enables defining arbitrary function | ||
* signatures with type based dispatch on first argument. | ||
*/ | ||
#ifndef TVM_RELAX_EXPR_FUNCTOR_H_ | ||
#define TVM_RELAX_EXPR_FUNCTOR_H_ | ||
|
||
#include <tvm/ir/error.h> | ||
#include <tvm/node/functor.h> | ||
#include <tvm/relax/expr.h> | ||
#include <tvm/relax/ir_builder.h> | ||
#include <tvm/relay/adt.h> | ||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/function.h> | ||
#include <tvm/relay/op.h> | ||
|
||
#include <deque> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <utility> | ||
#include <vector> | ||
namespace tvm { | ||
namespace relax { | ||
|
||
/*! | ||
* \brief A dynamical functor that dispatches on in the first Expr argument. | ||
* You can use this as a more powerful Visitor, since it allows you to | ||
* define function signatures of Visit Function. | ||
* | ||
* \sa tvm/ir_functor.h | ||
* | ||
* \tparam FType function signiture | ||
* This type is only defined for FType with function signature R(const Expr&, | ||
* Args...) | ||
*/ | ||
template <typename FType> | ||
class ExprFunctor; | ||
|
||
// functions to be overriden. | ||
#define EXPR_FUNCTOR_DEFAULT \ | ||
{ return VisitExprDefault_(op, std::forward<Args>(args)...); } | ||
|
||
#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \ | ||
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \ | ||
return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \ | ||
}); | ||
|
||
template <typename R, typename... Args> | ||
class ExprFunctor<R(const Expr& n, Args...)> { | ||
private: | ||
using TSelf = ExprFunctor<R(const Expr& n, Args...)>; | ||
using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; | ||
|
||
public: | ||
/*! \brief the result type of this functor */ | ||
using result_type = R; | ||
/*! \brief virtual destructor */ | ||
virtual ~ExprFunctor() {} | ||
/*! | ||
* \brief Same as call. | ||
* \param n The expression node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } | ||
/*! | ||
* \brief The functor call. | ||
* \param n The expression node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
virtual R VisitExpr(const Expr& n, Args... args) { | ||
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " | ||
"have generated invalid data."; | ||
static FType vtable = InitVTable(); | ||
return vtable(n, this, std::forward<Args>(args)...); | ||
} | ||
// Functions that can be overriden by subclass | ||
virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExprDefault_(const Object* op, Args...) { | ||
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); | ||
throw; | ||
} | ||
|
||
private: | ||
// initialize the vtable. | ||
static FType InitVTable() { | ||
FType vtable; | ||
// Set dispatch | ||
RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(VarNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(CallNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(IfNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(OpNode); | ||
RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); | ||
return vtable; | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief A simple visitor wrapper around ExprFunctor. | ||
* Recursively visit the content. | ||
* | ||
* ExprVisitor treats Expr as dataflow graph, | ||
* and only visit each Expr node once. | ||
*/ | ||
class ExprVisitor : public ExprFunctor<void(const Expr& n)> { | ||
public: | ||
void VisitExpr(const Expr& expr) override; | ||
void VisitExpr_(const ConstantNode* op) override; | ||
void VisitExpr_(const TupleNode* op) override; | ||
void VisitExpr_(const VarNode* op) override; | ||
void VisitExpr_(const DataflowVarNode* op) override; | ||
void VisitExpr_(const ShapeExprNode* op) override; | ||
void VisitExpr_(const ExternFuncNode* op) override; | ||
void VisitExpr_(const GlobalVarNode* op) override; | ||
void VisitExpr_(const FunctionNode* op) override; | ||
void VisitExpr_(const CallNode* op) override; | ||
void VisitExpr_(const SeqExprNode* op) override; | ||
void VisitExpr_(const IfNode* op) override; | ||
void VisitExpr_(const OpNode* op) override; | ||
void VisitExpr_(const TupleGetItemNode* op) override; | ||
|
||
virtual void VisitType(const Type& t); | ||
virtual void VisitSpan(const Span& span); | ||
virtual void VisitBinding(const Binding& binding); | ||
virtual void VisitVarBinding(const VarBinding& binding); | ||
virtual void VisitMatchShape(const MatchShape& binding); | ||
virtual void VisitBindingBlock(const BindingBlock& block); | ||
virtual void VisitDataflowBlock(const DataflowBlock& block); | ||
}; | ||
|
||
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); | ||
|
||
/*! | ||
* \brief A wrapper around ExprFunctor which functionally updates the AST. | ||
* | ||
* ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. | ||
* The mutated results are memoized in a map and reused so that | ||
* local transformation on the dataflow preserves the graph structure. | ||
*/ | ||
class ExprMutator : public ExprFunctor<Expr(const Expr&)> { | ||
public: | ||
/*! | ||
* \brief Mutate is alias for VisitExpr | ||
* \return expr. | ||
*/ | ||
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } | ||
Expr VisitExpr(const Expr& expr) override; | ||
Expr VisitExpr_(const ConstantNode* op) override; | ||
Expr VisitExpr_(const TupleNode* op) override; | ||
Expr VisitExpr_(const VarNode* op) override; | ||
Expr VisitExpr_(const DataflowVarNode* op) override; | ||
Expr VisitExpr_(const ShapeExprNode* op) override; | ||
Expr VisitExpr_(const ExternFuncNode* op) override; | ||
Expr VisitExpr_(const GlobalVarNode* op) override; | ||
Expr VisitExpr_(const FunctionNode* op) override; | ||
Expr VisitExpr_(const CallNode* op) override; | ||
Expr VisitExpr_(const SeqExprNode* op) override; | ||
Expr VisitExpr_(const IfNode* op) override; | ||
Expr VisitExpr_(const OpNode* op) override; | ||
Expr VisitExpr_(const TupleGetItemNode* op) override; | ||
|
||
/*! | ||
* \brief Used to visit the types inside of expressions. | ||
* | ||
* Can be overloaded to transform the types in arbitrary | ||
* ways, one way would be to define a sub-class of type | ||
* visitor for types which transform them appropriately. | ||
*/ | ||
virtual Type VisitType(const Type& t); | ||
virtual void VisitBinding(const Binding& binding); | ||
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); | ||
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& ir_builder); | ||
virtual BindingBlock VisitBindingBlock(const BindingBlock& block); | ||
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); | ||
|
||
protected: | ||
LazyIRBuilder irbuilder_; | ||
}; | ||
|
||
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes | ||
*/ | ||
class DataflowMutator : public ExprMutator { | ||
public: | ||
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); | ||
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder); | ||
|
||
protected: | ||
/*! \brief Look up the value binded to a var. */ | ||
Expr LookupVar(Var var); | ||
// A remapping table: pre var -> post var | ||
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> pre_post_var_map_; | ||
}; | ||
|
||
} // namespace relax | ||
} // namespace tvm | ||
#endif // TVM_RELAX_EXPR_FUNCTOR_H_ |
Oops, something went wrong.