Skip to content

Commit 841fc86

Browse files
altanhYuchenJin
authored andcommitted
Fixes and improvements (apache#24)
1 parent 3f62418 commit 841fc86

File tree

14 files changed

+156
-86
lines changed

14 files changed

+156
-86
lines changed

include/tvm/relax/block_builder.h

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define TVM_RELAX_BLOCK_BUILDER_H_
2626

2727
#include <tvm/ir/expr.h>
28+
#include <tvm/relax/utils.h>
2829
#include <tvm/relax/expr.h>
2930
#include <tvm/relay/expr.h>
3031
#include <tvm/runtime/object.h>
@@ -38,32 +39,6 @@ namespace relax {
3839

3940
class BlockBuilder;
4041

41-
/*!
42-
* \brief Utility data structure for generating unique names for IR construction.
43-
*/
44-
class NameTable {
45-
public:
46-
/*!
47-
* \brief Generate a unique name with a specified prefix.
48-
* \param prefix The name prefix.
49-
* \return The generated name.
50-
*/
51-
inline std::string GetUniqueName(std::string prefix) {
52-
std::replace(prefix.begin(), prefix.end(), '.', '_');
53-
std::string unique_prefix = prefix;
54-
auto it = alloc_map_.find(prefix);
55-
if (it != alloc_map_.end()) {
56-
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
57-
}
58-
}
59-
alloc_map_[unique_prefix] = 0;
60-
return unique_prefix;
61-
}
62-
63-
private:
64-
std::unordered_map<std::string, uint32_t> alloc_map_;
65-
};
66-
6742
/*!
6843
* \brief A builder that provides APIs to build Relax binding blocks.
6944
*/

include/tvm/relax/expr.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class ShapeExprNode : public ExprNode {
5454
void VisitAttrs(AttrVisitor* v) {
5555
v->Visit("values", &values);
5656
v->Visit("shape_", &shape_);
57-
v->Visit("checked_type_", &checked_type_);
57+
v->Visit("_checked_type_", &checked_type_);
5858
v->Visit("span", &span);
5959
}
6060

@@ -94,11 +94,11 @@ class VarNode : public ExprNode {
9494
const String& name_hint() const { return vid->name_hint; }
9595

9696
void VisitAttrs(AttrVisitor* v) {
97+
v->Visit("_checked_type_", &checked_type_);
9798
v->Visit("vid", &vid);
9899
v->Visit("type_annotation", &type_annotation);
99100
v->Visit("span", &span);
100101
v->Visit("shape_", &shape_);
101-
v->Visit("checked_type_", &checked_type_);
102102
}
103103

104104
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
@@ -143,7 +143,7 @@ class DataflowVarNode : public VarNode {
143143
v->Visit("type_annotation", &type_annotation);
144144
v->Visit("span", &span);
145145
v->Visit("shape_", &shape_);
146-
v->Visit("checked_type_", &checked_type_);
146+
v->Visit("_checked_type_", &checked_type_);
147147
}
148148

149149
bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
@@ -330,7 +330,7 @@ class SeqExprNode : public ExprNode {
330330
v->Visit("blocks", &blocks);
331331
v->Visit("body", &body);
332332
v->Visit("shape_", &shape_);
333-
v->Visit("checked_type_", &checked_type_);
333+
v->Visit("_checked_type_", &checked_type_);
334334
v->Visit("span", &span);
335335
}
336336

@@ -378,7 +378,7 @@ class FunctionNode : public BaseFuncNode {
378378
v->Visit("params", &params);
379379
v->Visit("body", &body);
380380
v->Visit("ret_type", &ret_type);
381-
v->Visit("checked_type_", &checked_type_);
381+
v->Visit("_checked_type_", &checked_type_);
382382
v->Visit("shape_", &shape_);
383383
v->Visit("span", &span);
384384
}

include/tvm/relax/expr_functor.h

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
139139
/*!
140140
* \brief A simple visitor wrapper around ExprFunctor.
141141
* Recursively visit the content.
142-
*
143-
* ExprVisitor treats Expr as dataflow graph,
144-
* and only visit each Expr node once.
145142
*/
146143
class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
147144
public:
@@ -167,9 +164,6 @@ class ExprVisitor : public ExprFunctor<void(const Expr& n)> {
167164
virtual void VisitMatchShape(const MatchShape& binding);
168165
virtual void VisitBindingBlock(const BindingBlock& block);
169166
virtual void VisitDataflowBlock(const DataflowBlock& block);
170-
171-
protected:
172-
std::unordered_map<const Object*, size_t> visit_counter_;
173167
};
174168

175169
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
@@ -221,19 +215,48 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
221215
virtual Type VisitType(const Type& t);
222216

223217
virtual void VisitBinding(const Binding& binding);
224-
virtual Var VisitVarBinding(const VarBinding& binding);
218+
virtual void VisitVarBinding(const VarBinding& binding);
225219
virtual void VisitMatchShape(const MatchShape& binding);
226220

227221
virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
228222
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);
229223

230224
protected:
231225
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
232-
/*! \brief Look up the value binded to a var. */
226+
227+
/*! \brief Look up the value of a variable. If the variable is bound, then returns the bound
228+
* value. Otherwise, returns the rewritten expression for the variable.
229+
*/
233230
Expr LookupVar(Var var);
234-
// A remapping table: pre var -> post var
235-
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
236-
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> memo_;
231+
232+
inline void UpdateMemo(Expr pre, Expr post) {
233+
if (const VarNode* var = pre.as<VarNode>()) {
234+
var_memo_[var->vid] = post;
235+
} else {
236+
expr_memo_[pre] = post;
237+
}
238+
}
239+
240+
inline Optional<Expr> LookupMemo(Expr pre) {
241+
if (pre.as<VarNode>()) {
242+
Id vid = Downcast<Var>(pre)->vid;
243+
if (var_memo_.count(vid)) {
244+
return var_memo_[vid];
245+
}
246+
} else {
247+
if (expr_memo_.count(pre)) {
248+
return expr_memo_[pre];
249+
}
250+
}
251+
return NullOpt;
252+
}
253+
254+
/*! \brief Variable memoization table using Id equality */
255+
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
256+
257+
/*! \brief Expr memoization table using pointer equality */
258+
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
259+
237260
std::shared_ptr<NameTable> name_table_;
238261
BlockBuilder builder_;
239262
};
@@ -245,7 +268,7 @@ class DataflowMutator : public ExprMutator {
245268
public:
246269
void VisitBinding(const Binding& binding) final;
247270

248-
virtual Var VisitDataflowVarBinding(const VarBinding& binding);
271+
virtual void VisitDataflowVarBinding(const VarBinding& binding);
249272
};
250273

251274
} // namespace relax

include/tvm/relax/ir_functor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
/*!
2121
* \file tvm/relax/ir_functor.h
22-
* \brief A generic visitor for traversing Relax IR nodes.
22+
* \brief A generic functor for working with Relax IR nodes.
23+
* \sa tvm/relax/expr_functor.h for common IR rewriting use-cases.
2324
*/
2425
#ifndef TVM_RELAX_IR_FUNCTOR_H_
2526
#define TVM_RELAX_IR_FUNCTOR_H_

include/tvm/relax/utils.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/utils.h
22+
* \brief Utility classes and functions for working with the Relax IR.
23+
*/
24+
#ifndef TVM_RELAX_UTILS_H_
25+
#define TVM_RELAX_UTILS_H_
26+
27+
#include <string>
28+
#include <algorithm>
29+
#include <unordered_map>
30+
31+
namespace tvm {
32+
namespace relax {
33+
34+
/*!
35+
* \brief Utility data structure for generating unique names for IR construction.
36+
*/
37+
class NameTable {
38+
public:
39+
/*!
40+
* \brief Generate a unique name with a specified prefix.
41+
* \param prefix The name prefix.
42+
* \return The generated name.
43+
*/
44+
inline std::string GetUniqueName(std::string prefix) {
45+
std::replace(prefix.begin(), prefix.end(), '.', '_');
46+
std::string unique_prefix = prefix;
47+
auto it = alloc_map_.find(prefix);
48+
if (it != alloc_map_.end()) {
49+
while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
50+
}
51+
}
52+
alloc_map_[unique_prefix] = 0;
53+
return unique_prefix;
54+
}
55+
56+
private:
57+
std::unordered_map<std::string, uint32_t> alloc_map_;
58+
};
59+
60+
} // namespace relax
61+
} // namespace tvm
62+
63+
#endif // TVM_RELAX_UTILS_H_

python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _traverse_expr(node):
111111
else:
112112
node_entry["inputs"].append([in_node_idx, 0, 0])
113113
infer_out = _infer_type(node)
114-
out_type = infer_out._checked_type_
114+
out_type = infer_out.checked_type_
115115
if isinstance(out_type, TensorType):
116116
node_entry["types"].append(out_type)
117117
elif isinstance(out_type, TupleType):

python/tvm/contrib/target/onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def infer_type(node):
8585
def call_node_infer_type(node):
8686
"""infer the output types of call node"""
8787
infer_out = infer_type(node)
88-
out_type = infer_out._checked_type_
88+
out_type = infer_out.checked_type_
8989
if isinstance(out_type, TensorType):
9090
types = [out_type]
9191
elif isinstance(out_type, TupleType):

python/tvm/ir/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def checked_type(self):
4545
checked_type : tvm.relay.Type
4646
The checked type.
4747
"""
48-
ret = self.checked_type_
48+
ret = self._checked_type_
4949
if ret is None:
5050
raise ValueError("The type checker has not populated the checked_type for this node")
5151
return ret

src/printer/relax_script_printer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424

2525
#include <tvm/ir/type_functor.h>
26-
#include <tvm/relax/block_builder.h>
26+
#include <tvm/relax/utils.h>
2727
#include <tvm/relax/ir_functor.h>
2828

2929
#include <algorithm>
@@ -397,7 +397,7 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
397397
}
398398
} else {
399399
AttrPrinter attr_printer(&kwargs, this);
400-
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&attr_printer);
400+
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
401401
}
402402
return kwargs;
403403
}

0 commit comments

Comments
 (0)