Skip to content

Commit 55ec7f4

Browse files
committed
Format and Buffer data structure (#1)
[SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> Fix AxisTree (#3) * fix axis tree * upd [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` [SparseTIR] Introduce SpIterVar (#6) * [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr [BugFix] Fix binary search & SpIterVar (#7) [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting [SparseTIR] Index Lowering (#8) * Add StmtFunctor/ExprFunctor for SparseBufferStore/Load * Add basic index lowering * Finish index lowering (maybe) * Address comments * Convert CRLF to LF Frontend update, demo scripts. (#10) * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * [SparseTIR] Introduce SpIterVar (#6) * [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr * [BugFix] Fix binary search & SpIterVar (#7) * [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting * upd * upd Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> [SparseTIR] SparseBlock on C++/Python side (#11) * Fix a bug in the last commit * SparseBlock on C++ & Python side [BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (#12) * Update `cord` and `pos` * Fix `idtype` * Formatting.. * Bug fix 1 * Move new special stmts * Parser for Axis and SpIterVar * Fix context_maintainer.py [SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (#13) * Enhance SparseBlock to have enough PrimFunc info * Remove `func_sparse_buffer_map_` * Don't print the map uh-huh [SparseTIR] Parser, Printer, Roundtrip (#14) * SparseBlock scope handler (part 1) * SparseBlock scope handler (part 2) * SparseBlock scope handler (part 3) * SparseBlock scope handler (fix 1) * Add SparseBufferLoad/Store on Python side * Parser for SparseBufferLoad/Store * Add SparseBlock to Python __init__ * StmtFunctor for SparseBlock * Ensure at least one dimension for SparseBuffer * Make `axis` field of SpIterVar mandatory * SparseBlock scope handler (fix 2) * Update Axis syntax by removing `name` parameter * Move to intrin.py * Add filed `from_sparse` to DenseFixedAxis * SparseTIR script printer * Roundtrip test * `update_symbol` bug fix * Fix attr visit in SparseBuffer * Define then compare in SparseBlock * Fix printer bug for SparseBuffer * Enable graph match for Axis and SparseBuffer * Complete HashReduce and EqualReduce for AxisTree and SparseBuffer * Fix typo * Rename test * Bug fix 1 * Bug fix 2 * Add more tests Move tests (#15) [SparseTIR] ReprPrinter for Axis and SpIterVar (#16) upd (#17) flatten (#18) ELL and BSR correctness test scripts (#19) [SparseTIR] SparseTIR Lowering (#20) * Fix a previous bug of sparse-fixed SpIterVar creation * Fix a previous bug in `GetDenseValue` * Refactor Collector and IndexTransformer * Construct block and loops * Fix a previous bug which rejects DV iters in collector * Update buffer map * Create root block * Fix bug of sparse-fixed SpIterVar creation * Fix bug on SpIterVar conversion (with refactor) * Fix bug when getting dependent SpIterVars * Fix bug on dependency map and index lowering * Full block read/write region * Test version 1 * Fix bug of loop order * Fix bug of batch-mm iterator ordering * Update PrimFunc args to use symbolic params * Fix bug of test "csr_element_wise" * Fix bug of index accumulation for sparse-fixed axis * Update correctness test * Test structural equality * Refactor and use Array fix nnz cols Add docstring for sparse tir lowering (#21) * add docstring * upd Add more examples part 1 (sddmm) (#22) * upd * upd * upd [SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (#23) * Test initialization * Fix a stupid bug of ReprPrinter * Add SparseBlockRV * Schedule: GetSparseBlock * Schedule: Reorder [SparseTIR][Schedule] GetSpIters (#24) remove hybrid script for successful compilation Add atomic intrinsic for output nonzero inference. (#25) * upd * upd Add "sparse" block attribute. (#26) Revert "remove hybrid script for successful compilation" This reverts commit eebd7c1. [SparseTIR] Hack `IsAffineBinding` check (#27) * [TensorIR][Schedule] Inherit block anotation upon creating new blocks * Fix SDDMM test * Hack IsAffineBinding for sparse blocks Axis Dependency Tree aware code-gen and bmm example (#28) * upd * upd * upd * upd * upd * upd * upd * upd * remove redundancy * fix * upd * upd Re-design Indices lowering (#29) * upd * upd * upd * upd * upd * init * format * fix * revise coding-style * format Complete indices lowering (#30) * upd * upd * upd * done * upd * passed test * upd Add more docstrings and depress warnings for new lowering algorithm. (#31) Refactor derived axis, frontend support of fusion. (#32) * upd * upd * fix Fatal bugfix and change the signature of DenseVariableAxis. (#33) Syntax simplification (#34) Change the order of generated blocks for block isolation. (#35) * upd * upd * upd Syntax of AttachAxis for BMM (#36) * upd * upd * upd [SparseTIR] Add "square sum" lowering test (#37) * Add square sum test * Remove pylint comment [BugFix] Fix offset caching in lowering (#38) * Hack compact dataflow check in a dirty way * Add two-K square sum test * Mark skipped tests * Fix offset saving in lowering Fusion syntax fix + SDDMM example. (#39) Some structure change on update offsets. (#40) [Refactor] SparseTIR Lowering (#41) * Take out methods in Scope * Refactor * Refactor "match" * Tweak scope contents * Refactor ViewIndexInAxis * Refactor Scope * SDDMM tests under implementation * Refactor block stack * Use Map for var_map * Extract NeedCreateNewBlock * Simplify SpIterVarToIterVar via GetIterExtent * Refactor NeedCreateNewBlock * Add docstring * Use "auto" correctly * Minor refactor and use some move Remove redundant analyzers (#42) Support indices lowering for attach and fuse. (#43) * upd * upd * upd Fix irregular BMM example. (#44) * upd * upd * upd * upd RGCN forward and butterfly pattern example. (#45) Fused SDDMM example. (#46) * upd * wip * fix Fix sparse reorder after refactor (#47) [Refactor] Refactor Unittest (#48) * upd * remove redundancy [Unittest] Correctness test for benchmarking scripts (#49) Bugfix and more test for axis fusion, new workload (#50) * upd * upd upd
1 parent 6931872 commit 55ec7f4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+6240
-23
lines changed

include/tvm/tir/builtin.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,21 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
494494
TVM_DLL const Op& tvm_warp_shuffle_down();
495495
TVM_DLL const Op& tvm_warp_activemask();
496496

497+
/*!
498+
* \brief Lower bound function for binary search.
499+
*/
500+
TVM_DLL const Op& tvm_lower_bound();
501+
502+
/*!
503+
* \brief Upper bound function for binary search.
504+
*/
505+
TVM_DLL const Op& tvm_upper_bound();
506+
507+
/*!
508+
* \brief Atomic add function.
509+
*/
510+
TVM_DLL const Op& tvm_atomic_add();
511+
497512
/*!
498513
* \brief Initialize the global barrier.
499514
* Call this at beginning of kernel that need global barrier.

include/tvm/tir/expr.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <tvm/runtime/container/string.h>
3535
#include <tvm/runtime/data_type.h>
3636
#include <tvm/tir/buffer.h>
37+
#include <tvm/tir/sparse.h>
3738
#include <tvm/tir/var.h>
3839

3940
#include <algorithm>
@@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr {
643644
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
644645
};
645646

647+
/*!
648+
* \brief Load value from the high dimension sparse buffer.
649+
*
650+
* \code
651+
*
652+
* value = buffer[i, j];
653+
*
654+
* \endcode
655+
* \sa SparseBufferStore
656+
*/
657+
class SparseBufferLoadNode : public PrimExprNode {
658+
public:
659+
/*! \brief The buffer to be loaded. */
660+
SparseBuffer buffer;
661+
/*! \brief The indices location to be loaded. */
662+
Array<PrimExpr> indices;
663+
664+
void VisitAttrs(AttrVisitor* v) {
665+
v->Visit("dtype", &(this->dtype));
666+
v->Visit("buffer", &buffer);
667+
v->Visit("indices", &indices);
668+
v->Visit("span", &span);
669+
}
670+
671+
bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const {
672+
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
673+
equal(indices, other->indices);
674+
}
675+
676+
void SHashReduce(SHashReducer hash_reduce) const {
677+
hash_reduce(dtype);
678+
hash_reduce(buffer);
679+
hash_reduce(indices);
680+
}
681+
682+
static constexpr const char* _type_key = "tir.SparseBufferLoad";
683+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode);
684+
};
685+
686+
/*!
687+
* \brief Managed reference to SparseBufferLoadNode.
688+
* \sa SparseBufferLoadNode
689+
*/
690+
class SparseBufferLoad : public PrimExpr {
691+
public:
692+
TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices,
693+
Span span = Span());
694+
695+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode);
696+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode);
697+
};
698+
646699
/*!
647700
* \brief Load value from the result produced by the producer.
648701
*

include/tvm/tir/expr_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
119119
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
120120
}
121121
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122+
virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122123
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123124
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124125
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -165,6 +166,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
165166
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
166167
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
167168
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
169+
IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode);
168170
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
169171
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
170172
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
@@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
217219
void VisitExpr_(const SizeVarNode* op) override;
218220
void VisitExpr_(const LoadNode* op) override;
219221
void VisitExpr_(const BufferLoadNode* op) override;
222+
void VisitExpr_(const SparseBufferLoadNode* op) override;
220223
void VisitExpr_(const ProducerLoadNode* op) override;
221224
void VisitExpr_(const LetNode* op) override;
222225
void VisitExpr_(const CallNode* op) override;
@@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
264267
PrimExpr VisitExpr_(const SizeVarNode* op) override;
265268
PrimExpr VisitExpr_(const LoadNode* op) override;
266269
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
270+
PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override;
267271
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
268272
PrimExpr VisitExpr_(const LetNode* op) override;
269273
PrimExpr VisitExpr_(const CallNode* op) override;

include/tvm/tir/op.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,38 @@ TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
820820
*/
821821
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
822822

823+
/*!
824+
* \brief Lower bound function for binary search
825+
* \param arr The buffer variable of the array to be looked up in
826+
* \param val The value to be looked up in the array
827+
* \param l The left boundary of the look-up range (inclusive)
828+
* \param r The right boundary of the look-up range (exclusive)
829+
* \param span The location of this operation in the source
830+
* \return The look-up result
831+
*/
832+
TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
833+
Span span = Span());
834+
835+
/*!
836+
* \brief Upper bound function for binary search
837+
* \param arr The buffer variable of the array to be looked up in
838+
* \param val The value to be looked up in the array
839+
* \param l The left boundary of the look-up range (inclusive)
840+
* \param r The right boundary of the look-up range (exclusive)
841+
* \param span The location of this operation in the source
842+
* \return The look-up result
843+
*/
844+
TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r,
845+
Span span = Span());
846+
847+
/*!
848+
* \brief Perform atomic add on ptr by val, and return the old value.
849+
* \param ptr The address to perform atomic add.
850+
* \param val The value to add.
851+
* \return The old result stored in ptr.
852+
*/
853+
TVM_DLL PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span = Span());
854+
823855
/*!
824856
* \brief Calculate trunc(x)
825857
* \param x The input expression.

include/tvm/tir/schedule/schedule.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <tvm/support/random_engine.h>
2323
#include <tvm/tir/schedule/state.h>
2424
#include <tvm/tir/schedule/trace.h>
25+
#include <tvm/tir/sparse.h>
2526

2627
namespace tvm {
2728
namespace tir {
@@ -85,6 +86,27 @@ using ExprRV = PrimExpr;
8586

8687
using ExprRVNode = PrimExprNode;
8788

89+
/**************** Random variable: SparseBlockRV ****************/
90+
91+
/*! \brief A random variable that evaluates to a TensorIR sparse block */
92+
class SparseBlockRVNode : public runtime::Object {
93+
public:
94+
void VisitAttrs(tvm::AttrVisitor* v) {}
95+
static constexpr const char* _type_key = "tir.SparseBlockRV";
96+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
97+
};
98+
99+
/*!
100+
* \brief Managed reference to SparseBlockRVNode
101+
* \sa SparseBlockRVNode
102+
*/
103+
class SparseBlockRV : public runtime::ObjectRef {
104+
public:
105+
/*! \brief Constructor */
106+
TVM_DLL SparseBlockRV();
107+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
108+
};
109+
88110
/**************** The Schedule class ****************/
89111

90112
class Schedule;
@@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
143165
* \return The corresponding expr
144166
*/
145167
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
168+
/*!
169+
* \brief Get the sparse block corresponding to the specific random variable
170+
* \param sp_block_rv The random variable to be looked up
171+
* \return SparseBlock The corresponding sparse block
172+
*/
173+
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
146174
/*!
147175
* \brief Get the block sref corresponding to the specific BlockRV
148176
* \param block_rv The BlockRV to be looked up
@@ -188,6 +216,11 @@ class ScheduleNode : public runtime::Object {
188216
* \param expr_rv The random variable to be removed
189217
*/
190218
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
219+
/*!
220+
* \brief Remove an sparse block random variable from the symbol table
221+
* \param sp_block_rv The random variable to be removed
222+
*/
223+
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;
191224

192225
public:
193226
/******** Schedule: Sampling ********/
@@ -524,6 +557,29 @@ class ScheduleNode : public runtime::Object {
524557
/******** Schedule: Misc ********/
525558
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
526559
virtual void EnterPostproc() = 0;
560+
/******** Schedule: SparseTIR schedules ********/
561+
/*!
562+
* \brief Retrieve a sparse block in a specific function with its name
563+
* \param name The name of the sparse block to be retrieved
564+
* \param func_name The name of the function
565+
* \return The sparse block retrieved
566+
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
567+
*/
568+
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
569+
/*!
570+
* \brief Retrieve the sparse iterators of a given sparse block
571+
* \param block_rv The block to be queried
572+
* \return The sparse iterators of the input sparse block
573+
*/
574+
virtual Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) = 0;
575+
/*!
576+
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
577+
* dependency.
578+
* \param block The block to be transformed
579+
* \param new_order The new order of the sparse iterators, whose length should equal to the number
580+
* of the input block's sparse iterators
581+
*/
582+
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
527583
};
528584

529585
/*!

include/tvm/tir/schedule/state.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ class ScheduleStateNode : public Object {
162162
* \return A boolean flag indicating if the block has quasi-affine bindings
163163
*/
164164
bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
165+
// (SparseTIR Hack) Always return true for sparse blocks.
166+
const auto* block = block_sref->StmtAs<BlockNode>();
167+
Optional<ObjectRef> sparse_attr = block != nullptr ? block->annotations.Get("sparse") : NullOpt;
168+
if (sparse_attr.defined() && sparse_attr.as<IntImmNode>()->value == 1) {
169+
return true;
170+
}
171+
165172
return GetBlockInfo(block_sref).affine_binding;
166173
}
167174
/*!

0 commit comments

Comments
 (0)