1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one
3+ * or more contributor license agreements. See the NOTICE file
4+ * distributed with this work for additional information
5+ * regarding copyright ownership. The ASF licenses this file
6+ * to you under the Apache License, Version 2.0 (the
7+ * "License"); you may not use this file except in compliance
8+ * with the License. You may obtain a copy of the License at
9+ *
10+ * http://www.apache.org/licenses/LICENSE-2.0
11+ *
12+ * Unless required by applicable law or agreed to in writing,
13+ * software distributed under the License is distributed on an
14+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+ * KIND, either express or implied. See the License for the
16+ * specific language governing permissions and limitations
17+ * under the License.
18+ */
19+
20+ /* !
21+ * \file tvm/relax/ir_functor.h
22+ * \brief A generic visitor for traversing Relax IR nodes.
23+ */
24+ #ifndef TVM_RELAX_IR_FUNCTOR_H_
25+ #define TVM_RELAX_IR_FUNCTOR_H_
26+
27+ #include < tvm/node/functor.h>
28+ #include < tvm/node/node.h>
29+ #include < tvm/relax/expr.h>
30+ #include < tvm/relay/expr.h>
31+
32+ namespace tvm {
33+ namespace relax {
34+
35+ template <typename FType>
36+ class IRFunctor ;
37+
38+ #define IR_FUNCTOR_DEFAULT \
39+ { return VisitNodeDefault_ (op, std::forward<Args>(args)...); }
40+
41+ #define RELAX_IR_FUNCTOR_DISPATCH (OP ) \
42+ vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
43+ return self->VisitNode_ (static_cast <const OP*>(n.get ()), std::forward<Args>(args)...); \
44+ });
45+
46+ template <typename R, typename ... Args>
47+ class IRFunctor <R(const ObjectRef& n, Args...)> {
48+ private:
49+ using TSelf = IRFunctor<R(const ObjectRef& n, Args...)>;
50+ using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
51+
52+ public:
53+ using result_type = R;
54+ virtual ~IRFunctor () {}
55+
56+ R operator ()(const ObjectRef& n, Args... args) {
57+ return VisitNode (n, std::forward<Args>(args)...);
58+ }
59+
60+ virtual R VisitNode (const ObjectRef& n, Args... args) {
61+ ICHECK (n.defined ()) << " Found null pointer node while traversing AST. The previous pass may "
62+ " have generated invalid data." ;
63+ static FType vtable = InitVTable ();
64+ return vtable (n, this , std::forward<Args>(args)...);
65+ }
66+
67+ // IR nodes inherited from Relay
68+ virtual R VisitNode_ (const relay::ConstantNode* op, Args... args) IR_FUNCTOR_DEFAULT;
69+ virtual R VisitNode_ (const relay::TupleNode* op, Args... args) IR_FUNCTOR_DEFAULT;
70+ virtual R VisitNode_ (const relay::GlobalVarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
71+ virtual R VisitNode_ (const relay::CallNode* op, Args... args) IR_FUNCTOR_DEFAULT;
72+ virtual R VisitNode_ (const relay::IfNode* op, Args... args) IR_FUNCTOR_DEFAULT;
73+ virtual R VisitNode_ (const OpNode* op, Args... args) IR_FUNCTOR_DEFAULT;
74+ virtual R VisitNode_ (const relay::TupleGetItemNode* op, Args... args) IR_FUNCTOR_DEFAULT;
75+
76+ // IR nodes introduced by Relax
77+ virtual R VisitNode_ (const relax::VarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
78+ virtual R VisitNode_ (const relax::DataflowVarNode* op, Args... args) IR_FUNCTOR_DEFAULT;
79+ virtual R VisitNode_ (const relax::ShapeExprNode* op, Args... args) IR_FUNCTOR_DEFAULT;
80+ virtual R VisitNode_ (const relax::MatchShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT;
81+ virtual R VisitNode_ (const relax::VarBindingNode* op, Args... args) IR_FUNCTOR_DEFAULT;
82+ virtual R VisitNode_ (const relax::BindingBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT;
83+ virtual R VisitNode_ (const relax::DataflowBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT;
84+ virtual R VisitNode_ (const relax::SeqExprNode* op, Args... args) IR_FUNCTOR_DEFAULT;
85+ virtual R VisitNode_ (const relax::FunctionNode* op, Args... args) IR_FUNCTOR_DEFAULT;
86+ virtual R VisitNode_ (const relax::ExternFuncNode* op, Args... args) IR_FUNCTOR_DEFAULT;
87+
88+ virtual R VisitNodeDefault_ (const Object* op, Args...) {
89+ LOG (FATAL) << " no default visitor implemented for " << op->GetTypeKey ();
90+ throw ;
91+ }
92+
93+ private:
94+ static FType InitVTable () {
95+ FType vtable;
96+ RELAX_IR_FUNCTOR_DISPATCH (relay::ConstantNode);
97+ RELAX_IR_FUNCTOR_DISPATCH (relay::TupleNode);
98+ RELAX_IR_FUNCTOR_DISPATCH (relay::GlobalVarNode);
99+ RELAX_IR_FUNCTOR_DISPATCH (relay::CallNode);
100+ RELAX_IR_FUNCTOR_DISPATCH (relay::IfNode);
101+ RELAX_IR_FUNCTOR_DISPATCH (OpNode);
102+ RELAX_IR_FUNCTOR_DISPATCH (relay::TupleGetItemNode);
103+ RELAX_IR_FUNCTOR_DISPATCH (relax::VarNode);
104+ RELAX_IR_FUNCTOR_DISPATCH (relax::DataflowVarNode);
105+ RELAX_IR_FUNCTOR_DISPATCH (relax::ShapeExprNode);
106+ RELAX_IR_FUNCTOR_DISPATCH (relax::MatchShapeNode);
107+ RELAX_IR_FUNCTOR_DISPATCH (relax::VarBindingNode);
108+ RELAX_IR_FUNCTOR_DISPATCH (relax::BindingBlockNode);
109+ RELAX_IR_FUNCTOR_DISPATCH (relax::DataflowBlockNode);
110+ RELAX_IR_FUNCTOR_DISPATCH (relax::SeqExprNode);
111+ RELAX_IR_FUNCTOR_DISPATCH (relax::FunctionNode);
112+ RELAX_IR_FUNCTOR_DISPATCH (relax::ExternFuncNode);
113+ return vtable;
114+ }
115+ };
116+
117+ } // namespace relax
118+ } // namespace tvm
119+
120+ #endif // TVM_RELAX_IR_FUNCTOR_H_
0 commit comments