Skip to content

Commit e028205

Browse files
YuchenJinaltanh
andcommitted
Redesign IRBuilder to BlockBuilder (apache#22)
* init * update * update * test case working * update and add multi block test case * check in * fixes * fix * update * add * update * add * update * address comments. Co-authored-by: Altan Haan <ahaan@octoml.ai>
1 parent acffe5f commit e028205

File tree

15 files changed

+935
-1171
lines changed

15 files changed

+935
-1171
lines changed

include/tvm/relax/block_builder.h

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/block_builder.h
22+
* \brief The utility for constructing Relax binding blocks.
23+
*/
24+
#ifndef TVM_RELAX_BLOCK_BUILDER_H_
25+
#define TVM_RELAX_BLOCK_BUILDER_H_
26+
27+
#include <tvm/ir/expr.h>
28+
#include <tvm/relax/expr.h>
29+
#include <tvm/relay/expr.h>
30+
#include <tvm/runtime/object.h>
31+
#include <tvm/runtime/registry.h>
32+
#include <tvm/support/with.h>
33+
34+
#include <memory>
35+
36+
namespace tvm {
37+
namespace relax {
38+
39+
class BlockBuilder;
40+
41+
/*!
42+
* \brief Utility data structure for generating unique names for IR construction.
43+
*/
44+
class NameTable {
45+
public:
46+
/*!
47+
* \brief Generate a unique name with a specified prefix.
48+
* \param prefix The name prefix.
49+
* \return The generated name.
50+
*/
51+
inline std::string GetUniqueName(std::string prefix) {
52+
std::replace(prefix.begin(), prefix.end(), '.', '_');
53+
std::string unique_prefix = prefix;
54+
auto it = alloc_map_.find(prefix);
55+
if (it != alloc_map_.end()) {
56+
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
57+
}
58+
}
59+
alloc_map_[unique_prefix] = 0;
60+
return unique_prefix;
61+
}
62+
63+
private:
64+
std::unordered_map<std::string, uint32_t> alloc_map_;
65+
};
66+
67+
/*!
68+
* \brief A builder that provides APIs to build Relax binding blocks.
69+
*/
70+
class BlockBuilderNode : public Object {
71+
public:
72+
BlockBuilderNode(std::shared_ptr<NameTable> name_table) : name_table_(name_table) {}
73+
74+
~BlockBuilderNode();
75+
76+
BlockBuilderNode() { name_table_ = std::make_shared<NameTable>(); }
77+
78+
/*! \brief Begin to build a DataflowBlock. */
79+
void BeginDataflowBlock();
80+
/*! \brief Begin to build a BindingBlock. */
81+
void BeginBindingBlock();
82+
/*!
83+
* \brief End building a BindingBlock.
84+
* \return The BindingBlock being built.
85+
*/
86+
BindingBlock EndBlock();
87+
/*!
88+
* \brief Check if the block being built is DataflowBlock or not.
89+
* \return A boolean that indicates if the block being built is DataflowBlock or not.
90+
*/
91+
inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; }
92+
/*!
93+
* \brief Emits an Expr, and returns the variable it is bound to.
94+
* \param expr The Expr to be emitted.
95+
* \param name_hint Name hint for the bound variable.
96+
* \return The new variable that \p expr is bound to.
97+
*/
98+
virtual Var Emit(const Expr& expr, std::string name_hint = "");
99+
/*!
100+
* \brief Emits a variable binding, and returns the bound Var.
101+
* \param binding The variable binding.
102+
* \return The bound variable.
103+
*/
104+
virtual Var Emit(const VarBinding& binding);
105+
/*!
106+
* \brief Emit a MatchShape.
107+
* \param value The value of the MatchShape to be emitted.
108+
* \param pattern The pattern of the MatchShape to be emitted.
109+
* \param name_hint Name hint for the bound variable.
110+
* \return The variable bound to the MatchShape.
111+
*/
112+
Var EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern, std::string name_hint = "");
113+
/*!
114+
* \brief Emit a MatchShape binding.
115+
* \param binding The MatchShape binding to be emitted.
116+
* \return The variable bound to the MatchShape.
117+
*/
118+
Var EmitMatchShape(const MatchShape& binding);
119+
/*!
120+
* \brief Generate an output for the current dataflow block.
121+
* \param output The output variable of the block.
122+
* \param name_hint Name hint for the bound variable.
123+
* \return The variable bound to \p output.
124+
*/
125+
Var EmitOutput(const Expr& output, std::string name_hint = "");
126+
/*!
127+
* \brief Generate an output for the current dataflow block.
128+
* \param binding The output binding to output.
129+
* \return The variable bound to \p output.
130+
*/
131+
Var EmitOutput(const VarBinding& binding);
132+
/*!
133+
* \brief Lookup a var in the binding table \p var_map_.
134+
* \param var The input var.
135+
* \return The Expr bound to the input \p var.
136+
*/
137+
Expr LookupVar(const Var& var);
138+
/*!
139+
* \brief Check if two shape expressions can be proven equal at compile time.
140+
* \param lhs The input lhs shape.
141+
* \param rhs The input rhs shape.
142+
* \return Whether we can prove lhs shape is the same as the rhs shape.
143+
*/
144+
bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);
145+
/*!
146+
* \brief Normalize an Expr to complete its shape and type.
147+
* \param expr The input expr.
148+
* \return The expr with normalized shape and type.
149+
*/
150+
Expr Normalize(const Expr& expr);
151+
/*!
152+
* \brief Create a BlockBuilder.
153+
* \return The created BlockBuilder.
154+
*/
155+
TVM_DLL static BlockBuilder Create();
156+
157+
void VisitAttrs(AttrVisitor* v) {}
158+
159+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
160+
static constexpr const char* _type_key = "relax.BlockBuilder";
161+
TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object);
162+
163+
private:
164+
Var Emit(const Expr& expr, bool is_dataflow, std::string name_hint);
165+
166+
protected:
167+
/*!
168+
* \brief A representation of a block frame.
169+
*
170+
* A block frame is a record containing the bindings needed
171+
* to build a binding block, and a boolean to indicate if the
172+
* block being built is a DataflowBlock or not.
173+
*/
174+
struct BlockFrame {
175+
Array<Binding> bindings;
176+
bool is_dataflow;
177+
};
178+
friend class BlockBuilder;
179+
/*!
180+
* \brief Get the current block frame.
181+
* \return The current block frame.
182+
*/
183+
BlockFrame* CurrentFrame();
184+
/*! \brief A stack to store block frames. */
185+
std::stack<BlockFrame> block_stack_;
186+
/*! \brief A diagnostic context for reporting errors. */
187+
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));
188+
/*! \brief A binding table that maps var to value. */
189+
// TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder
190+
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_map_;
191+
/*! \brief A name table to get unique names for IR construction. */
192+
std::shared_ptr<NameTable> name_table_;
193+
};
194+
195+
class BlockBuilder : public ObjectRef {
196+
public:
197+
TVM_DLL explicit BlockBuilder(std::shared_ptr<NameTable> name_table);
198+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
199+
};
200+
201+
} // namespace relax
202+
} // namespace tvm
203+
204+
#endif // TVM_RELAX_BLOCK_BUILDER_H_

include/tvm/relax/expr_functor.h

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
#include <tvm/ir/error.h>
2929
#include <tvm/node/functor.h>
30+
#include <tvm/relax/block_builder.h>
3031
#include <tvm/relax/expr.h>
31-
#include <tvm/relax/ir_builder.h>
3232
#include <tvm/relay/adt.h>
3333
#include <tvm/relay/expr.h>
3434
#include <tvm/relay/function.h>
@@ -167,6 +167,9 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
167167
virtual void VisitMatchShape(const MatchShape& binding);
168168
virtual void VisitBindingBlock(const BindingBlock& block);
169169
virtual void VisitDataflowBlock(const DataflowBlock& block);
170+
171+
protected:
172+
std::unordered_map<const Object*, size_t> visit_counter_;
170173
};
171174

172175
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
@@ -180,11 +183,22 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
180183
*/
181184
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
182185
public:
186+
ExprMutator() {
187+
name_table_ = std::make_shared<NameTable>();
188+
builder_ = BlockBuilder(name_table_);
189+
}
190+
183191
/*!
184192
* \brief Mutate is alias for VisitExpr
185193
* \return expr.
186194
*/
187-
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
195+
Expr Mutate(const Expr& expr) {
196+
if (memo_.count(expr) == 0) {
197+
memo_[expr] = this->VisitExpr(expr);
198+
}
199+
return Downcast<Expr>(memo_[expr]);
200+
}
201+
188202
Expr VisitExpr(const Expr& expr) override;
189203
Expr VisitExpr_(const ConstantNode* op) override;
190204
Expr VisitExpr_(const TupleNode* op) override;
@@ -208,28 +222,32 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
208222
* visitor for types which transform them appropriately.
209223
*/
210224
virtual Type VisitType(const Type& t);
211-
virtual void VisitBinding(const Binding& binding, IRBuilder& builder);
212-
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
213-
virtual void VisitMatchShape(const MatchShape& binding, IRBuilder& builder);
225+
226+
virtual void VisitBinding(const Binding& binding);
227+
virtual Var VisitVarBinding(const VarBinding& binding);
228+
virtual void VisitMatchShape(const MatchShape& binding);
214229
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
215230
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
216231

217232
protected:
218-
IRBuilder builder_;
233+
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
234+
/*! \brief Look up the value binded to a var. */
235+
Expr LookupVar(Var var);
236+
// A remapping table: pre var -> post var
237+
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
238+
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> memo_;
239+
std::shared_ptr<NameTable> name_table_;
240+
BlockBuilder builder_;
219241
};
220242

243+
// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
221244
/*! \brief Dataflow Graph Rewriting for Custom Rewriting Passes
222245
*/
223246
class DataflowMutator : public ExprMutator {
224247
public:
225-
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
226-
virtual Var VisitVarBinding(const VarBinding& binding, IRBuilder& builder);
248+
void VisitBinding(const Binding& binding) final;
227249

228-
protected:
229-
/*! \brief Look up the value binded to a var. */
230-
Expr LookupVar(Var var);
231-
// A remapping table: pre var -> post var
232-
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> pre_post_var_map_;
250+
virtual Var VisitDataflowVarBinding(const VarBinding& binding);
233251
};
234252

235253
} // namespace relax

0 commit comments

Comments
 (0)