Skip to content

Commit 330c9fa

Browse files
tqchenSiyuan Feng
authored andcommitted
[REFACTOR] StructInfo M3: MatchShape=>MatchCast (apache#323)
* Introduce match cast, and code changes along * add match_cast parser support (#9) * Match cast support for VMShapeLower CanonicalizeBinding * Remove `match_shape` (#12) * Refactor ExprVisitor/Mutator to consider Expr in StructInfo. Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
1 parent baea09a commit 330c9fa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+789
-651
lines changed

include/tvm/relax/block_builder.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,13 @@ class BlockBuilderNode : public Object {
160160
virtual Var Emit(Expr expr, String name_hint = "") = 0;
161161

162162
/*!
163-
* \brief Emit a MatchShape.
164-
* \param value The value of the MatchShape to be emitted.
165-
* \param pattern The pattern of the MatchShape to be emitted.
163+
* \brief Emit a MatchCast.
164+
* \param value The input value.
165+
* \param struct_info The struct info to be matched.
166166
* \param name_hint Name hint for the bound variable.
167-
* \return The variable bound to the MatchShape.
167+
* \return The variable bound to the MatchCast.
168168
*/
169-
virtual Var EmitMatchShape(Expr value, Array<PrimExpr> pattern, String name_hint = "") = 0;
169+
virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0;
170170

171171
/*!
172172
* \brief Generate an output for the current dataflow block.

include/tvm/relax/expr.h

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -531,12 +531,10 @@ class Constant : public Expr {
531531
/*! \brief The base class of a variable binding in Relax. */
532532
class BindingNode : public Object {
533533
public:
534+
/*! \brief The return variable to bound to. */
535+
Var var;
534536
mutable Span span;
535537

536-
void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
537-
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
538-
void SHashReduce(SHashReducer hash_reduce) const {}
539-
540538
static constexpr const char* _type_key = "relax.expr.Binding";
541539
static constexpr const bool _type_has_method_sequal_reduce = true;
542540
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -555,51 +553,61 @@ class Binding : public ObjectRef {
555553
using ContainerType = BindingNode;
556554
};
557555

558-
/*! \brief Symbolic shape match, binds the variable of the lhs with the rhs. */
559-
class MatchShape;
560-
class MatchShapeNode : public BindingNode {
556+
/*!
557+
* \brief Runtime-match the value to the struct info.
558+
*
559+
* This operation does runtime check, populates the un-defined symbolic shape vars
560+
* and vars in struct_info in first occurance, and insert equality assertions in
561+
* other cases.
562+
*/
563+
class MatchCastNode : public BindingNode {
561564
public:
565+
/*! \brief The input value to match cast. */
562566
Expr value;
563-
Array<PrimExpr> pattern;
564-
Var var;
567+
/*! \brief The struct info pattern to match to. */
568+
StructInfo struct_info;
565569

566570
void VisitAttrs(AttrVisitor* v) {
567-
v->Visit("value", &value);
568-
v->Visit("pattern", &pattern);
569571
v->Visit("var", &var);
572+
v->Visit("value", &value);
573+
v->Visit("struct_info", &struct_info);
570574
v->Visit("span", &span);
571575
}
572576

573-
bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
577+
bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
574578
// NOTE: pattern can contain ShapeExpr which defines the vars
575-
return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) &&
576-
equal.DefEqual(var, other->var);
579+
return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) &&
580+
equal(value, other->value);
577581
}
578582

579583
void SHashReduce(SHashReducer hash_reduce) const {
580584
// NOTE: pattern can contain ShapeExpr which defines the vars
581-
hash_reduce(value);
582-
hash_reduce.DefHash(pattern);
583585
hash_reduce.DefHash(var);
586+
hash_reduce.DefHash(struct_info);
587+
hash_reduce(value);
584588
}
585589

586-
static constexpr const char* _type_key = "relax.expr.MatchShape";
590+
static constexpr const char* _type_key = "relax.expr.MatchCast";
587591
static constexpr const bool _type_has_method_sequal_reduce = true;
588592
static constexpr const bool _type_has_method_shash_reduce = true;
589-
TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode);
593+
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode);
590594
};
591595

592-
class MatchShape : public Binding {
596+
/*!
597+
* \brief Managed reference to MatchCastNode.
598+
* \sa MatchCastNode
599+
*/
600+
class MatchCast : public Binding {
593601
public:
594-
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span = Span());
595-
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
596-
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
602+
TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());
603+
604+
TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode);
605+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode);
597606
};
598607

599-
class VarBinding;
600608
class VarBindingNode : public BindingNode {
601609
public:
602-
Var var;
610+
/*! \brief The binding value. */
603611
Expr value;
604612

605613
void VisitAttrs(AttrVisitor* v) {
@@ -628,8 +636,6 @@ class VarBinding : public Binding {
628636
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode);
629637
};
630638

631-
class BindingBlock;
632-
633639
class BindingBlockNode : public Object {
634640
public:
635641
mutable Span span;

include/tvm/relax/expr_functor.h

Lines changed: 118 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@
2929
#include <tvm/node/functor.h>
3030
#include <tvm/relax/block_builder.h>
3131
#include <tvm/relax/expr.h>
32-
#include <tvm/relay/adt.h>
33-
#include <tvm/relay/expr.h>
34-
#include <tvm/relay/function.h>
32+
#include <tvm/relax/struct_info.h>
33+
#include <tvm/relax/struct_info_functor.h>
3534
#include <tvm/relay/op.h>
3635
#include <tvm/tir/function.h>
3736

@@ -213,7 +212,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
213212
virtual void VisitBinding(const Binding& binding);
214213
// specific leaf level visitor functions
215214
virtual void VisitBinding_(const VarBindingNode* binding);
216-
virtual void VisitBinding_(const MatchShapeNode* binding);
215+
virtual void VisitBinding_(const MatchCastNode* binding);
217216
// second level dispatching based on binding value type.
218217
// these dispatching functions get called from first-level dispatch on VarBinding
219218
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
@@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
244243
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
245244
*/
246245
virtual void VisitVarDef(const Var& var);
246+
247+
/*!
248+
* \brief Visit struct_info may recursively contain Expr/PrimExpr.
249+
*
250+
* By default, this function recurse into struct info such as
251+
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
252+
* accordingly. It does not recurse into FunctionStructInfo as it does
253+
* not contain Expr defined in the current scope.
254+
*
255+
* Pass writers can overload this function to change to other behaviors.
256+
* For example, if we are not interested in Expr in StructInfo, we can
257+
* override this function by a no-op.
258+
*
259+
* \param struct_info Input struct info field.
260+
*/
261+
virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);
262+
247263
// specific leaf level visitor functions
248264
virtual void VisitVarDef_(const VarNode* var);
249265
virtual void VisitVarDef_(const DataflowVarNode* var);
@@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
258274
tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
259275
// initialize the vtable.
260276
static VisitBindingVTable InitVisitBindingVTable();
277+
/*!
278+
* \brief Private internal struct info field visitor.
279+
*
280+
* Support default visiting of struct info field and recursive into
281+
* their Expr fields.
282+
*
283+
* We use component instead of sub-classing so there can be other
284+
* joint inheritance between ExprVisitor and StructInfoVisitor.
285+
*/
286+
class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
287+
public:
288+
explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);
289+
290+
// Override defaults in struct info visitor.
291+
void VisitStructInfoExprField(const Expr& expr) final;
292+
void VisitStructInfoExprField(const PrimExpr& expr) final;
293+
void VisitStructInfo_(const FuncStructInfoNode* op) final;
294+
295+
private:
296+
ExprVisitor* parent_;
297+
};
298+
// This visitor is not visible to child classes and only
299+
// used to supportd default visiting behavior.
300+
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
261301
};
262302

263303
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
@@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
309349
* Can be overloaded to transform the shape expressions.
310350
*/
311351
virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);
352+
353+
/*!
354+
* \brief Visit struct_info that may recursively contain Expr/PrimExpr.
355+
*
356+
* By default, this function recurse into struct info such as
357+
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
358+
* accordingly. It does not recurse into FunctionStructInfo as it does
359+
* not contain Expr defined in the current scope.
360+
*
361+
* Pass writers can overload this function to change to other behaviors.
362+
* For example, if in Expr in StructInfo won't change, we can
363+
* override this function by an identity function.
364+
*
365+
* \param struct_info Input struct info field.
366+
* \return The updated struct info.
367+
*/
368+
virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);
369+
370+
protected:
371+
/*!
372+
* \brief Check whether VisitExprDepStructInfoField change struct_info.
373+
* \return Whether struct info changed.
374+
* \note This function is used by mutator implementations to check if
375+
* previous Expr update will trigger a change in struct_info.
376+
* If change is detected, the implementation can generate a fresh
377+
* node without struct_info, and trigger normalizer to re-derive.
378+
*/
379+
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {
380+
if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
381+
return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
382+
} else {
383+
return true;
384+
}
385+
}
386+
387+
private:
388+
/*!
389+
* \brief Private internal struct info field visitor to support
390+
* Default visiting of struct info field and recursive into their Expr fields.
391+
*
392+
* We use component instead of sub-classing so there can be other
393+
* joint inheritance between ExprMutator and StructInfoMutator.
394+
*/
395+
class DefaultStructInfoFieldMutator : public StructInfoMutator {
396+
public:
397+
explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);
398+
399+
// Override defaults in struct info visitor.
400+
Expr VisitStructInfoExprField(const Expr& expr) final;
401+
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
402+
StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;
403+
404+
private:
405+
ExprMutatorBase* parent_;
406+
};
407+
// This visitor is not visible to child classes and only
408+
// used to supportd default visiting behavior.
409+
DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
312410
};
313411

314412
/*!
@@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase {
324422

325423
ExprMutator(Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create(mod); }
326424
Expr VisitExpr(const Expr& expr) override;
327-
Expr VisitExpr_(const TupleNode* op) override;
328425
Expr VisitExpr_(const VarNode* op) override;
329426
Expr VisitExpr_(const DataflowVarNode* op) override;
330427
Expr VisitExpr_(const FunctionNode* op) override;
@@ -338,7 +435,7 @@ class ExprMutator : public ExprMutatorBase {
338435
virtual void VisitBinding(const Binding& binding);
339436
// specific leaf level visitor functions
340437
virtual void VisitBinding_(const VarBindingNode* binding);
341-
virtual void VisitBinding_(const MatchShapeNode* binding);
438+
virtual void VisitBinding_(const MatchCastNode* binding);
342439
// second level dispatching based on binding value type.
343440
// these dispatching functions get called from first-level dispatch on VarBinding
344441
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
@@ -484,9 +581,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
484581
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
485582
* function. */
486583
PackedFunc f_visit_var_binding_{nullptr};
487-
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
584+
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
488585
* function. */
489-
PackedFunc f_visit_match_shape_{nullptr};
586+
PackedFunc f_visit_match_cast_{nullptr};
490587
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
491588
* function. */
492589
PackedFunc f_visit_binding_block{nullptr};
@@ -523,8 +620,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
523620
void VisitBinding_(const VarBindingNode* binding)
524621
PY_EXPR_VISITOR_DEFAULT(GetRef<VarBinding>(binding), f_visit_var_binding_,
525622
ExprVisitor::VisitBinding_(binding));
526-
void VisitBinding_(const MatchShapeNode* binding)
527-
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchShape>(binding), f_visit_match_shape_,
623+
void VisitBinding_(const MatchCastNode* binding)
624+
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchCast>(binding), f_visit_match_cast_,
528625
ExprVisitor::VisitBinding_(binding));
529626

530627
void VisitBindingBlock(const BindingBlock& block)
@@ -602,7 +699,7 @@ class PyExprVisitor : public ObjectRef {
602699
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
603700
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
604701
* binding)`.
605-
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
702+
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
606703
* binding)`.
607704
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
608705
* block)`.
@@ -624,7 +721,7 @@ class PyExprVisitor : public ObjectRef {
624721
PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_,
625722
PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_,
626723
PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
627-
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
724+
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
628725
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
629726
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
630727
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
@@ -649,7 +746,7 @@ class PyExprVisitor : public ObjectRef {
649746
n->f_visit_op_ = f_visit_op_;
650747
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
651748
n->f_visit_var_binding_ = f_visit_var_binding_;
652-
n->f_visit_match_shape_ = f_visit_match_shape_;
749+
n->f_visit_match_cast_ = f_visit_match_cast_;
653750
n->f_visit_binding_block_ = f_visit_binding_block_;
654751
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
655752
n->f_visit_var_def_ = f_visit_var_def_;
@@ -702,9 +799,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
702799
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
703800
* function. */
704801
PackedFunc f_visit_var_binding_{nullptr};
705-
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
802+
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
706803
* function. */
707-
PackedFunc f_visit_match_shape_{nullptr};
804+
PackedFunc f_visit_match_cast_{nullptr};
708805
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
709806
* function. */
710807
PackedFunc f_visit_binding_block{nullptr};
@@ -748,9 +845,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
748845
ExprMutator::VisitBinding_(binding);
749846
}
750847

751-
void VisitBinding_(const MatchShapeNode* binding) {
752-
if (f_visit_match_shape_ != nullptr)
753-
f_visit_match_shape_(GetRef<MatchShape>(binding));
848+
void VisitBinding_(const MatchCastNode* binding) {
849+
if (f_visit_match_cast_ != nullptr)
850+
f_visit_match_cast_(GetRef<MatchCast>(binding));
754851
else
755852
ExprMutator::VisitBinding_(binding);
756853
}
@@ -866,7 +963,7 @@ class PyExprMutator : public ObjectRef {
866963
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
867964
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
868965
* binding)`.
869-
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
966+
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
870967
* binding)`.
871968
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
872969
* block)`.
@@ -889,7 +986,7 @@ class PyExprMutator : public ObjectRef {
889986
PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_,
890987
PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_,
891988
PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
892-
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
989+
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
893990
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
894991
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
895992
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
@@ -911,7 +1008,7 @@ class PyExprMutator : public ObjectRef {
9111008
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
9121009
n->f_visit_binding = f_visit_binding;
9131010
n->f_visit_var_binding_ = f_visit_var_binding_;
914-
n->f_visit_match_shape_ = f_visit_match_shape_;
1011+
n->f_visit_match_cast_ = f_visit_match_cast_;
9151012
n->f_visit_binding_block = f_visit_binding_block;
9161013
n->f_visit_binding_block_ = f_visit_binding_block_;
9171014
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;

0 commit comments

Comments
 (0)