2929#include < tvm/node/functor.h>
3030#include < tvm/relax/block_builder.h>
3131#include < tvm/relax/expr.h>
32- #include < tvm/relay/adt.h>
33- #include < tvm/relay/expr.h>
34- #include < tvm/relay/function.h>
32+ #include < tvm/relax/struct_info.h>
33+ #include < tvm/relax/struct_info_functor.h>
3534#include < tvm/relay/op.h>
3635#include < tvm/tir/function.h>
3736
@@ -213,7 +212,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
213212 virtual void VisitBinding (const Binding& binding);
214213 // specific leaf level visitor functions
215214 virtual void VisitBinding_ (const VarBindingNode* binding);
216- virtual void VisitBinding_ (const MatchShapeNode * binding);
215+ virtual void VisitBinding_ (const MatchCastNode * binding);
217216 // second level dispatching based on binding value type.
218217 // these dispatching functions get called from first-level dispatch on VarBinding
219218 virtual void VisitBinding_ (const VarBindingNode* binding, const ConstantNode* val);
@@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
244243 * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
245244 */
246245 virtual void VisitVarDef (const Var& var);
246+
247+ /* !
248+ * \brief Visit struct_info may recursively contain Expr/PrimExpr.
249+ *
250+ * By default, this function recurse into struct info such as
251+ * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
252+ * accordingly. It does not recurse into FunctionStructInfo as it does
253+ * not contain Expr defined in the current scope.
254+ *
255+ * Pass writers can overload this function to change to other behaviors.
256+ * For example, if we are not interested in Expr in StructInfo, we can
257+ * override this function by a no-op.
258+ *
259+ * \param struct_info Input struct info field.
260+ */
261+ virtual void VisitExprDepStructInfoField (const StructInfo& struct_info);
262+
247263 // specific leaf level visitor functions
248264 virtual void VisitVarDef_ (const VarNode* var);
249265 virtual void VisitVarDef_ (const DataflowVarNode* var);
@@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
258274 tvm::NodeFunctor<void (const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
259275 // initialize the vtable.
260276 static VisitBindingVTable InitVisitBindingVTable ();
277+ /* !
278+ * \brief Private internal struct info field visitor.
279+ *
280+ * Support default visiting of struct info field and recursive into
281+ * their Expr fields.
282+ *
283+ * We use component instead of sub-classing so there can be other
284+ * joint inheritance between ExprVisitor and StructInfoVisitor.
285+ */
286+ class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
287+ public:
288+ explicit DefaultStructInfoFieldVisitor (ExprVisitor* parent);
289+
290+ // Override defaults in struct info visitor.
291+ void VisitStructInfoExprField (const Expr& expr) final ;
292+ void VisitStructInfoExprField (const PrimExpr& expr) final ;
293+ void VisitStructInfo_ (const FuncStructInfoNode* op) final ;
294+
295+ private:
296+ ExprVisitor* parent_;
297+ };
298+ // This visitor is not visible to child classes and only
299+ // used to supportd default visiting behavior.
300+ DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this };
261301};
262302
263303void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
@@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
309349 * Can be overloaded to transform the shape expressions.
310350 */
311351 virtual PrimExpr VisitPrimExpr (const PrimExpr& expr);
352+
353+ /* !
354+ * \brief Visit struct_info that may recursively contain Expr/PrimExpr.
355+ *
356+ * By default, this function recurse into struct info such as
357+ * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
358+ * accordingly. It does not recurse into FunctionStructInfo as it does
359+ * not contain Expr defined in the current scope.
360+ *
361+ * Pass writers can overload this function to change to other behaviors.
362+ * For example, if in Expr in StructInfo won't change, we can
363+ * override this function by an identity function.
364+ *
365+ * \param struct_info Input struct info field.
366+ * \return The updated struct info.
367+ */
368+ virtual StructInfo VisitExprDepStructInfoField (const StructInfo& struct_info);
369+
370+ protected:
371+ /* !
372+ * \brief Check whether VisitExprDepStructInfoField change struct_info.
373+ * \return Whether struct info changed.
374+ * \note This function is used by mutator implementations to check if
375+ * previous Expr update will trigger a change in struct_info.
376+ * If change is detected, the implementation can generate a fresh
377+ * node without struct_info, and trigger normalizer to re-derive.
378+ */
379+ bool VisitAndCheckStructInfoFieldUnchanged (const ObjectRef& struct_info) {
380+ if (const StructInfoNode* sinfo = struct_info.as <StructInfoNode>()) {
381+ return this ->VisitExprDepStructInfoField (GetRef<StructInfo>(sinfo)).same_as (struct_info);
382+ } else {
383+ return true ;
384+ }
385+ }
386+
387+ private:
388+ /* !
389+ * \brief Private internal struct info field visitor to support
390+ * Default visiting of struct info field and recursive into their Expr fields.
391+ *
392+ * We use component instead of sub-classing so there can be other
393+ * joint inheritance between ExprMutator and StructInfoMutator.
394+ */
395+ class DefaultStructInfoFieldMutator : public StructInfoMutator {
396+ public:
397+ explicit DefaultStructInfoFieldMutator (ExprMutatorBase* parent);
398+
399+ // Override defaults in struct info visitor.
400+ Expr VisitStructInfoExprField (const Expr& expr) final ;
401+ PrimExpr VisitStructInfoExprField (const PrimExpr& expr) final ;
402+ StructInfo VisitStructInfo_ (const FuncStructInfoNode* op) final ;
403+
404+ private:
405+ ExprMutatorBase* parent_;
406+ };
407+ // This visitor is not visible to child classes and only
408+ // used to supportd default visiting behavior.
409+ DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this };
312410};
313411
314412/* !
@@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase {
324422
325423 ExprMutator (Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create (mod); }
326424 Expr VisitExpr (const Expr& expr) override ;
327- Expr VisitExpr_ (const TupleNode* op) override ;
328425 Expr VisitExpr_ (const VarNode* op) override ;
329426 Expr VisitExpr_ (const DataflowVarNode* op) override ;
330427 Expr VisitExpr_ (const FunctionNode* op) override ;
@@ -338,7 +435,7 @@ class ExprMutator : public ExprMutatorBase {
338435 virtual void VisitBinding (const Binding& binding);
339436 // specific leaf level visitor functions
340437 virtual void VisitBinding_ (const VarBindingNode* binding);
341- virtual void VisitBinding_ (const MatchShapeNode * binding);
438+ virtual void VisitBinding_ (const MatchCastNode * binding);
342439 // second level dispatching based on binding value type.
343440 // these dispatching functions get called from first-level dispatch on VarBinding
344441 virtual void VisitBinding_ (const VarBindingNode* binding, const ConstantNode* val);
@@ -484,9 +581,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
484581 /* ! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
485582 * function. */
486583 PackedFunc f_visit_var_binding_{nullptr };
487- /* ! \brief The packed function to the `VisitBinding_(const MatchShapeNode * binding)`
584+ /* ! \brief The packed function to the `VisitBinding_(const MatchCastNode * binding)`
488585 * function. */
489- PackedFunc f_visit_match_shape_ {nullptr };
586+ PackedFunc f_visit_match_cast_ {nullptr };
490587 /* ! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
491588 * function. */
492589 PackedFunc f_visit_binding_block{nullptr };
@@ -523,8 +620,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
523620 void VisitBinding_ (const VarBindingNode* binding)
524621 PY_EXPR_VISITOR_DEFAULT(GetRef<VarBinding>(binding), f_visit_var_binding_,
525622 ExprVisitor::VisitBinding_(binding));
526- void VisitBinding_ (const MatchShapeNode * binding)
527- PY_EXPR_VISITOR_DEFAULT(GetRef<MatchShape >(binding), f_visit_match_shape_ ,
623+ void VisitBinding_ (const MatchCastNode * binding)
624+ PY_EXPR_VISITOR_DEFAULT(GetRef<MatchCast >(binding), f_visit_match_cast_ ,
528625 ExprVisitor::VisitBinding_(binding));
529626
530627 void VisitBindingBlock (const BindingBlock& block)
@@ -602,7 +699,7 @@ class PyExprVisitor : public ObjectRef {
602699 * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
603700 * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
604701 * binding)`.
605- * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode *
702+ * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode *
606703 * binding)`.
607704 * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
608705 * block)`.
@@ -624,7 +721,7 @@ class PyExprVisitor : public ObjectRef {
624721 PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_,
625722 PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_,
626723 PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
627- PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_ ,
724+ PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_ ,
628725 PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
629726 PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
630727 PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
@@ -649,7 +746,7 @@ class PyExprVisitor : public ObjectRef {
649746 n->f_visit_op_ = f_visit_op_;
650747 n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
651748 n->f_visit_var_binding_ = f_visit_var_binding_;
652- n->f_visit_match_shape_ = f_visit_match_shape_ ;
749+ n->f_visit_match_cast_ = f_visit_match_cast_ ;
653750 n->f_visit_binding_block_ = f_visit_binding_block_;
654751 n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
655752 n->f_visit_var_def_ = f_visit_var_def_;
@@ -702,9 +799,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
702799 /* ! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
703800 * function. */
704801 PackedFunc f_visit_var_binding_{nullptr };
705- /* ! \brief The packed function to the `VisitBinding_(const MatchShapeNode * binding)`
802+ /* ! \brief The packed function to the `VisitBinding_(const MatchCastNode * binding)`
706803 * function. */
707- PackedFunc f_visit_match_shape_ {nullptr };
804+ PackedFunc f_visit_match_cast_ {nullptr };
708805 /* ! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
709806 * function. */
710807 PackedFunc f_visit_binding_block{nullptr };
@@ -748,9 +845,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
748845 ExprMutator::VisitBinding_ (binding);
749846 }
750847
751- void VisitBinding_ (const MatchShapeNode * binding) {
752- if (f_visit_match_shape_ != nullptr )
753- f_visit_match_shape_ (GetRef<MatchShape >(binding));
848+ void VisitBinding_ (const MatchCastNode * binding) {
849+ if (f_visit_match_cast_ != nullptr )
850+ f_visit_match_cast_ (GetRef<MatchCast >(binding));
754851 else
755852 ExprMutator::VisitBinding_ (binding);
756853 }
@@ -866,7 +963,7 @@ class PyExprMutator : public ObjectRef {
866963 * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
867964 * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
868965 * binding)`.
869- * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode *
966+ * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode *
870967 * binding)`.
871968 * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
872969 * block)`.
@@ -889,7 +986,7 @@ class PyExprMutator : public ObjectRef {
889986 PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_,
890987 PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_,
891988 PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
892- PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_ ,
989+ PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_ ,
893990 PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
894991 PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
895992 PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
@@ -911,7 +1008,7 @@ class PyExprMutator : public ObjectRef {
9111008 n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
9121009 n->f_visit_binding = f_visit_binding;
9131010 n->f_visit_var_binding_ = f_visit_var_binding_;
914- n->f_visit_match_shape_ = f_visit_match_shape_ ;
1011+ n->f_visit_match_cast_ = f_visit_match_cast_ ;
9151012 n->f_visit_binding_block = f_visit_binding_block;
9161013 n->f_visit_binding_block_ = f_visit_binding_block_;
9171014 n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
0 commit comments