Skip to content

Commit 02bcb25

Browse files
YuchenJinyongwww
authored andcommitted
Visit shape in Visitor/Mutator (apache#45)
1 parent 18385c5 commit 02bcb25

File tree

5 files changed

+154
-68
lines changed

5 files changed

+154
-68
lines changed

include/tvm/relax/expr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ using Expr = RelayExpr;
3535
using ExprNode = RelayExprNode;
3636
using relay::Call;
3737
using relay::CallNode;
38+
using relay::Constant;
3839
using relay::ConstantNode;
3940
using relay::Id;
4041
using relay::If;

include/tvm/relax/expr_functor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
293293
}
294294

295295
/*!
296-
* \brief Create a new var with specified shape and type if it's original shape or type does not
296+
* \brief Create a new var with specified shape and type if the original var's shape or type does not
297297
* match with the specified ones.
298298
* \param var The var to be updated.
299299
* \param shape The specified shape.

src/relax/backend/vm/vm_shape_lower.cc

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,38 +32,6 @@
3232
namespace tvm {
3333
namespace relax {
3434

35-
/*!
36-
* \brief Visitor to apply a function to every Expr it visits. Also applies the function
37-
* to the shape field of the var definition site if the var's shape is a ShapeExpr.
38-
*/
39-
class ExprApplyVisitWithShape : public ExprVisitor {
40-
public:
41-
explicit ExprApplyVisitWithShape(std::function<void(const Expr&)> f) : f_(f) {}
42-
43-
void VisitVarDef(const Var& var) {
44-
if (var.as<DataflowVarNode>()) {
45-
this->VisitExpr(Downcast<DataflowVar>(var));
46-
} else {
47-
this->VisitExpr(var);
48-
}
49-
if (var->shape_.operator bool() && var->shape_.value().as<ShapeExprNode>()) {
50-
f_(Downcast<ShapeExpr>(var->shape_.value()));
51-
}
52-
}
53-
54-
void VisitExpr(const Expr& e) final {
55-
ExprVisitor::VisitExpr(e);
56-
f_(e);
57-
}
58-
59-
private:
60-
std::function<void(const Expr&)> f_;
61-
};
62-
63-
void PostOrderVisitWithShape(const Expr& e, std::function<void(const Expr&)> fvisit) {
64-
ExprApplyVisitWithShape(fvisit).VisitExpr(e);
65-
}
66-
6735
class VMShapeLowerMutator : public ExprMutator {
6836
public:
6937
static DataType ShapeDType() { return DataType::Int(64); };
@@ -125,9 +93,7 @@ class VMShapeLowerMutator : public ExprMutator {
12593
builder_->BeginBindingBlock();
12694
builder_->Emit(VarBinding(
12795
shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})})));
128-
Array<Var> params;
12996
for (Var param : node->params) {
130-
params.push_back(this->VisitVarDef(param));
13197
if (param->shape_.operator bool() && param->shape_.value().as<ShapeExprNode>()) {
13298
Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh");
13399
StoreShape(shape, Downcast<ShapeExpr>(param->shape_.value())->values);
@@ -150,7 +116,7 @@ class VMShapeLowerMutator : public ExprMutator {
150116
blocks.push_back(builder_->EndBlock());
151117
new_body = SeqExpr(blocks, new_body);
152118

153-
return Function(node->name, params, new_body, ret_type);
119+
return Function(node->name, node->params, new_body, ret_type);
154120
}
155121

156122
tir::PrimFunc CalculateShape(ShapeExpr s) {
@@ -201,7 +167,7 @@ class VMShapeLowerMutator : public ExprMutator {
201167
}
202168
}
203169
};
204-
PostOrderVisitWithShape(expr, func);
170+
PostOrderVisit(expr, func);
205171
return ret;
206172
}
207173

src/relax/ir/expr_functor.cc

Lines changed: 123 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@
3333
namespace tvm {
3434
namespace relax {
3535

36-
void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); }
36+
void ExprVisitor::VisitExpr_(const ConstantNode* op) {
37+
this->VisitSpan(op->span);
38+
39+
if (op->shape_) {
40+
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
41+
}
42+
}
3743

3844
void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); }
3945

@@ -42,20 +48,20 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) {
4248
for (Expr field : op->fields) {
4349
this->VisitExpr(field);
4450
}
51+
52+
if (op->shape_) {
53+
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
54+
}
4555
}
4656

57+
// Visit the use-site of a defined Var
4758
void ExprVisitor::VisitExpr_(const VarNode* op) {
4859
this->VisitSpan(op->span);
49-
if (op->type_annotation.defined()) {
50-
this->VisitType(op->type_annotation.value());
51-
}
5260
}
5361

62+
// Visit the use-site of a defined DataflowVar
5463
void ExprVisitor::VisitExpr_(const DataflowVarNode* op) {
5564
this->VisitSpan(op->span);
56-
if (op->type_annotation.defined()) {
57-
this->VisitType(op->type_annotation.value());
58-
}
5965
}
6066

6167
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
@@ -78,6 +84,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) {
7884
for (Expr arg : op->args) {
7985
this->VisitExpr(arg);
8086
}
87+
88+
if (op->shape_) {
89+
this->VisitExpr(Downcast<Expr>(op->shape_.value()));
90+
}
8191
}
8292

8393
void ExprVisitor::VisitExpr_(const IfNode* op) {
@@ -142,19 +152,25 @@ void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
142152
if (var->type_annotation.defined()) {
143153
this->VisitType(var->type_annotation.value());
144154
}
155+
156+
if (var->shape_) {
157+
this->VisitExpr(Downcast<Expr>(var->shape_.value()));
158+
}
145159
}
146160

147161
void ExprVisitor::VisitVarDef_(const VarNode* var) {
148162
this->VisitSpan(var->span);
149163
if (var->type_annotation.defined()) {
150164
this->VisitType(var->type_annotation.value());
151165
}
152-
}
153166

154-
void ExprVisitor::VisitExpr(const Expr& expr) {
155-
ExprFunctor::VisitExpr(expr);
167+
if (var->shape_) {
168+
this->VisitExpr(Downcast<Expr>(var->shape_.value()));
169+
}
156170
}
157171

172+
void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); }
173+
158174
void ExprVisitor::VisitBinding(const Binding& binding) {
159175
if (const auto* node = binding.as<VarBindingNode>()) {
160176
VisitBinding_(node);
@@ -209,23 +225,48 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex
209225
// ==================
210226
// ExprMutator
211227

212-
Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef<Expr>(op); }
228+
Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
229+
Expr new_shape;
230+
bool unchanged = true;
231+
if (op->shape_) {
232+
new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
233+
if (!new_shape.same_as(op->shape_)) {
234+
unchanged = false;
235+
}
236+
}
237+
238+
if (unchanged) {
239+
return GetRef<Expr>(op);
240+
} else {
241+
Expr new_constant = Constant(op->data, op->span);
242+
new_constant->shape_ = new_shape;
243+
return new_constant;
244+
}
245+
}
213246

214247
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op); }
215248

216249
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
250+
bool unchanged = true;
217251
tvm::Array<Expr> fields;
218-
bool all_fields_unchanged = true;
219252
for (Expr field : op->fields) {
220253
Expr new_field = this->VisitExpr(field);
221254
fields.push_back(new_field);
222-
all_fields_unchanged &= new_field.same_as(field);
255+
unchanged &= new_field.same_as(field);
256+
}
257+
258+
Expr new_shape;
259+
if (op->shape_) {
260+
new_shape = this->VisitExpr(Downcast<Expr>(op->shape_.value()));
261+
unchanged &= new_shape.same_as(op->shape_);
223262
}
224263

225-
if (all_fields_unchanged) {
264+
if (unchanged) {
226265
return GetRef<Expr>(op);
227266
} else {
228-
return Tuple(fields, op->span);
267+
Expr new_tuple = Tuple(fields, op->span);
268+
new_tuple->shape_ = new_shape;
269+
return new_tuple;
229270
}
230271
}
231272

@@ -288,10 +329,18 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
288329
unchanged &= new_arg.same_as(arg);
289330
}
290331

332+
Expr new_shape;
333+
if (call_node->shape_) {
334+
new_shape = this->VisitExpr(Downcast<Expr>(call_node->shape_.value()));
335+
unchanged &= new_shape.same_as(call_node->shape_);
336+
}
337+
291338
if (unchanged) {
292339
return GetRef<Expr>(call_node);
293340
} else {
294-
return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
341+
Expr new_call = Call(new_op, call_args, call_node->attrs, ty_args, call_node->span);
342+
new_call->shape_ = new_shape;
343+
return new_call;
295344
}
296345
}
297346

@@ -424,29 +473,75 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
424473
}
425474

426475
Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
476+
bool type_unchanged = true;
477+
Type new_type;
427478
if (var->type_annotation.defined()) {
428-
Type type = this->VisitType(var->type_annotation.value());
429-
if (!var->type_annotation.same_as(type)) {
430-
Var new_var = DataflowVar(var->vid, NullOpt, type, var->span);
479+
new_type = this->VisitType(var->type_annotation.value());
480+
type_unchanged &= new_type.same_as(var->type_annotation);
481+
}
482+
483+
bool shape_unchanged = true;
484+
Expr new_shape;
485+
if (var->shape_) {
486+
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
487+
shape_unchanged &= new_shape.same_as(var->shape_);
488+
}
489+
490+
if (type_unchanged && shape_unchanged) {
491+
return GetRef<Var>(var);
492+
} else {
493+
Var new_var;
494+
if (type_unchanged) {
495+
new_var = DataflowVar(var->vid, NullOpt, var->type_annotation, var->span);
496+
} else {
497+
new_var = DataflowVar(var->vid, NullOpt, new_type, var->span);
498+
}
499+
500+
if (shape_unchanged) {
431501
new_var->shape_ = var->shape_;
432-
this->var_remap_[var->vid] = new_var;
433-
return new_var;
502+
} else {
503+
new_var->shape_ = new_shape;
434504
}
505+
506+
this->var_remap_[var->vid] = new_var;
507+
return new_var;
435508
}
436-
return GetRef<Var>(var);
437509
}
438510

439511
Var ExprMutator::VisitVarDef_(const VarNode* var) {
512+
bool type_unchanged = true;
513+
Type new_type;
440514
if (var->type_annotation.defined()) {
441-
Type type = this->VisitType(var->type_annotation.value());
442-
if (!var->type_annotation.same_as(type)) {
443-
Var new_var = Var(var->vid, NullOpt, type, var->span);
515+
new_type = this->VisitType(var->type_annotation.value());
516+
type_unchanged &= new_type.same_as(var->type_annotation);
517+
}
518+
519+
bool shape_unchanged = true;
520+
Expr new_shape;
521+
if (var->shape_) {
522+
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
523+
shape_unchanged &= new_shape.same_as(var->shape_);
524+
}
525+
526+
if (type_unchanged && shape_unchanged) {
527+
return GetRef<Var>(var);
528+
} else {
529+
Var new_var;
530+
if (type_unchanged) {
531+
new_var = Var(var->vid, NullOpt, var->type_annotation, var->span);
532+
} else {
533+
new_var = Var(var->vid, NullOpt, new_type, var->span);
534+
}
535+
536+
if (shape_unchanged) {
444537
new_var->shape_ = var->shape_;
445-
this->var_remap_[var->vid] = new_var;
446-
return new_var;
538+
} else {
539+
new_var->shape_ = new_shape;
447540
}
541+
542+
this->var_remap_[var->vid] = new_var;
543+
return new_var;
448544
}
449-
return GetRef<Var>(var);
450545
}
451546

452547
Expr ExprMutator::VisitExpr(const Expr& expr) {

tests/python/relax/test_transform.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ def test_fma_rewrite():
5353
assert structural_equal(gv0.shape, relax.ShapeExpr([m, n]))
5454

5555
# after rewrite
56-
passes = [relax.transform.FMARewrite()]
57-
seq = tvm.transform.Sequential(passes)
58-
new_mod = seq(mod)
56+
new_mod = relax.transform.FMARewrite()(mod)
5957
func = new_mod["main"]
6058
v1 = func.body.blocks[0].bindings[1].var
6159
s1 = func.body.blocks[0].bindings[1].value
@@ -69,6 +67,31 @@ def test_fma_rewrite():
6967
assert gv0 == v0
7068
assert type(func.body.blocks[0].bindings[1].var) == relax.Var
7169

70+
def test_visit_shape():
71+
@tvm.script.ir_module
72+
class TestVisitShape:
73+
@R.function
74+
def foo(x: Tensor[(m, n), "float32"]):
75+
gv0 = R.add(x, x)
76+
return gv0
77+
78+
mod = TestVisitShape
79+
80+
shape_expr = []
81+
def fvisit(e):
82+
if isinstance(e, relax.ShapeExpr):
83+
nonlocal shape_expr
84+
shape_expr.append(e)
85+
86+
relax.analysis.post_order_visit(mod["foo"], fvisit)
87+
88+
# should have visited ShapeExpr 3 times
89+
# the first time being visited is x.shape
90+
# the last two times are the call node's shape and gv0's shape
91+
assert len(shape_expr) == 3
92+
assert shape_expr[0] == mod["foo"].params[0].shape
93+
assert shape_expr[1] == shape_expr[2]
94+
7295

7396
def test_to_non_dataflow():
7497
@tvm.script.ir_module
@@ -312,6 +335,7 @@ def foo(x: Tensor[(m, n), "float32"]):
312335

313336
if __name__ == "__main__":
314337
test_fma_rewrite()
338+
test_visit_shape()
315339
test_to_non_dataflow()
316340
test_call_dps_rewrite()
317341
test_vm_memory_lower()

0 commit comments

Comments
 (0)