3333namespace tvm {
3434namespace relax {
3535
36- void ExprVisitor::VisitExpr_ (const ConstantNode* op) { this ->VisitSpan (op->span ); }
36+ void ExprVisitor::VisitExpr_ (const ConstantNode* op) {
37+ this ->VisitSpan (op->span );
38+
39+ if (op->shape_ ) {
40+ this ->VisitExpr (Downcast<Expr>(op->shape_ .value ()));
41+ }
42+ }
3743
3844void ExprVisitor::VisitExpr_ (const GlobalVarNode* op) { this ->VisitSpan (op->span ); }
3945
@@ -42,20 +48,20 @@ void ExprVisitor::VisitExpr_(const TupleNode* op) {
4248 for (Expr field : op->fields ) {
4349 this ->VisitExpr (field);
4450 }
51+
52+ if (op->shape_ ) {
53+ this ->VisitExpr (Downcast<Expr>(op->shape_ .value ()));
54+ }
4555}
4656
57+ // Visit the use-site of a defined Var
4758void ExprVisitor::VisitExpr_ (const VarNode* op) {
4859 this ->VisitSpan (op->span );
49- if (op->type_annotation .defined ()) {
50- this ->VisitType (op->type_annotation .value ());
51- }
5260}
5361
62+ // Visit the use-site of a defined DataflowVar
5463void ExprVisitor::VisitExpr_ (const DataflowVarNode* op) {
5564 this ->VisitSpan (op->span );
56- if (op->type_annotation .defined ()) {
57- this ->VisitType (op->type_annotation .value ());
58- }
5965}
6066
6167void ExprVisitor::VisitExpr_ (const FunctionNode* op) {
@@ -78,6 +84,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) {
7884 for (Expr arg : op->args ) {
7985 this ->VisitExpr (arg);
8086 }
87+
88+ if (op->shape_ ) {
89+ this ->VisitExpr (Downcast<Expr>(op->shape_ .value ()));
90+ }
8191}
8292
8393void ExprVisitor::VisitExpr_ (const IfNode* op) {
@@ -142,19 +152,25 @@ void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
142152 if (var->type_annotation .defined ()) {
143153 this ->VisitType (var->type_annotation .value ());
144154 }
155+
156+ if (var->shape_ ) {
157+ this ->VisitExpr (Downcast<Expr>(var->shape_ .value ()));
158+ }
145159}
146160
147161void ExprVisitor::VisitVarDef_ (const VarNode* var) {
148162 this ->VisitSpan (var->span );
149163 if (var->type_annotation .defined ()) {
150164 this ->VisitType (var->type_annotation .value ());
151165 }
152- }
153166
154- void ExprVisitor::VisitExpr (const Expr& expr) {
155- ExprFunctor::VisitExpr (expr);
167+ if (var->shape_ ) {
168+ this ->VisitExpr (Downcast<Expr>(var->shape_ .value ()));
169+ }
156170}
157171
172+ void ExprVisitor::VisitExpr (const Expr& expr) { ExprFunctor::VisitExpr (expr); }
173+
158174void ExprVisitor::VisitBinding (const Binding& binding) {
159175 if (const auto * node = binding.as <VarBindingNode>()) {
160176 VisitBinding_ (node);
@@ -209,23 +225,48 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex
209225// ==================
210226// ExprMutator
211227
212- Expr ExprMutator::VisitExpr_ (const ConstantNode* op) { return GetRef<Expr>(op); }
228+ Expr ExprMutator::VisitExpr_ (const ConstantNode* op) {
229+ Expr new_shape;
230+ bool unchanged = true ;
231+ if (op->shape_ ) {
232+ new_shape = this ->VisitExpr (Downcast<Expr>(op->shape_ .value ()));
233+ if (!new_shape.same_as (op->shape_ )) {
234+ unchanged = false ;
235+ }
236+ }
237+
238+ if (unchanged) {
239+ return GetRef<Expr>(op);
240+ } else {
241+ Expr new_constant = Constant (op->data , op->span );
242+ new_constant->shape_ = new_shape;
243+ return new_constant;
244+ }
245+ }
213246
214247Expr ExprMutator::VisitExpr_ (const GlobalVarNode* op) { return GetRef<Expr>(op); }
215248
216249Expr ExprMutator::VisitExpr_ (const TupleNode* op) {
250+ bool unchanged = true ;
217251 tvm::Array<Expr> fields;
218- bool all_fields_unchanged = true ;
219252 for (Expr field : op->fields ) {
220253 Expr new_field = this ->VisitExpr (field);
221254 fields.push_back (new_field);
222- all_fields_unchanged &= new_field.same_as (field);
255+ unchanged &= new_field.same_as (field);
256+ }
257+
258+ Expr new_shape;
259+ if (op->shape_ ) {
260+ new_shape = this ->VisitExpr (Downcast<Expr>(op->shape_ .value ()));
261+ unchanged &= new_shape.same_as (op->shape_ );
223262 }
224263
225- if (all_fields_unchanged ) {
264+ if (unchanged ) {
226265 return GetRef<Expr>(op);
227266 } else {
228- return Tuple (fields, op->span );
267+ Expr new_tuple = Tuple (fields, op->span );
268+ new_tuple->shape_ = new_shape;
269+ return new_tuple;
229270 }
230271}
231272
@@ -288,10 +329,18 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
288329 unchanged &= new_arg.same_as (arg);
289330 }
290331
332+ Expr new_shape;
333+ if (call_node->shape_ ) {
334+ new_shape = this ->VisitExpr (Downcast<Expr>(call_node->shape_ .value ()));
335+ unchanged &= new_shape.same_as (call_node->shape_ );
336+ }
337+
291338 if (unchanged) {
292339 return GetRef<Expr>(call_node);
293340 } else {
294- return Call (new_op, call_args, call_node->attrs , ty_args, call_node->span );
341+ Expr new_call = Call (new_op, call_args, call_node->attrs , ty_args, call_node->span );
342+ new_call->shape_ = new_shape;
343+ return new_call;
295344 }
296345}
297346
@@ -424,29 +473,75 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
424473}
425474
426475Var ExprMutator::VisitVarDef_ (const DataflowVarNode* var) {
476+ bool type_unchanged = true ;
477+ Type new_type;
427478 if (var->type_annotation .defined ()) {
428- Type type = this ->VisitType (var->type_annotation .value ());
429- if (!var->type_annotation .same_as (type)) {
430- Var new_var = DataflowVar (var->vid , NullOpt, type, var->span );
479+ new_type = this ->VisitType (var->type_annotation .value ());
480+ type_unchanged &= new_type.same_as (var->type_annotation );
481+ }
482+
483+ bool shape_unchanged = true ;
484+ Expr new_shape;
485+ if (var->shape_ ) {
486+ new_shape = this ->VisitExpr (Downcast<Expr>(var->shape_ .value ()));
487+ shape_unchanged &= new_shape.same_as (var->shape_ );
488+ }
489+
490+ if (type_unchanged && shape_unchanged) {
491+ return GetRef<Var>(var);
492+ } else {
493+ Var new_var;
494+ if (type_unchanged) {
495+ new_var = DataflowVar (var->vid , NullOpt, var->type_annotation , var->span );
496+ } else {
497+ new_var = DataflowVar (var->vid , NullOpt, new_type, var->span );
498+ }
499+
500+ if (shape_unchanged) {
431501 new_var->shape_ = var->shape_ ;
432- this -> var_remap_ [var-> vid ] = new_var;
433- return new_var;
502+ } else {
503+ new_var-> shape_ = new_shape ;
434504 }
505+
506+ this ->var_remap_ [var->vid ] = new_var;
507+ return new_var;
435508 }
436- return GetRef<Var>(var);
437509}
438510
439511Var ExprMutator::VisitVarDef_ (const VarNode* var) {
512+ bool type_unchanged = true ;
513+ Type new_type;
440514 if (var->type_annotation .defined ()) {
441- Type type = this ->VisitType (var->type_annotation .value ());
442- if (!var->type_annotation .same_as (type)) {
443- Var new_var = Var (var->vid , NullOpt, type, var->span );
515+ new_type = this ->VisitType (var->type_annotation .value ());
516+ type_unchanged &= new_type.same_as (var->type_annotation );
517+ }
518+
519+ bool shape_unchanged = true ;
520+ Expr new_shape;
521+ if (var->shape_ ) {
522+ new_shape = this ->VisitExpr (Downcast<Expr>(var->shape_ .value ()));
523+ shape_unchanged &= new_shape.same_as (var->shape_ );
524+ }
525+
526+ if (type_unchanged && shape_unchanged) {
527+ return GetRef<Var>(var);
528+ } else {
529+ Var new_var;
530+ if (type_unchanged) {
531+ new_var = Var (var->vid , NullOpt, var->type_annotation , var->span );
532+ } else {
533+ new_var = Var (var->vid , NullOpt, new_type, var->span );
534+ }
535+
536+ if (shape_unchanged) {
444537 new_var->shape_ = var->shape_ ;
445- this -> var_remap_ [var-> vid ] = new_var;
446- return new_var;
538+ } else {
539+ new_var-> shape_ = new_shape ;
447540 }
541+
542+ this ->var_remap_ [var->vid ] = new_var;
543+ return new_var;
448544 }
449- return GetRef<Var>(var);
450545}
451546
452547Expr ExprMutator::VisitExpr (const Expr& expr) {
0 commit comments