|
34 | 34 | #include <tvm/runtime/container/string.h> |
35 | 35 | #include <tvm/runtime/data_type.h> |
36 | 36 | #include <tvm/tir/buffer.h> |
| 37 | +#include <tvm/tir/sparse.h> |
37 | 38 | #include <tvm/tir/var.h> |
38 | 39 |
|
39 | 40 | #include <algorithm> |
@@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr { |
643 | 644 | TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); |
644 | 645 | }; |
645 | 646 |
|
| 647 | +/*! |
| 648 | + * \brief Load value from the high dimension sparse buffer. |
| 649 | + * |
| 650 | + * \code |
| 651 | + * |
| 652 | + * value = buffer[i, j]; |
| 653 | + * |
| 654 | + * \endcode |
| 655 | + * \sa SparseBufferStore |
| 656 | + */ |
| 657 | +class SparseBufferLoadNode : public PrimExprNode { |
| 658 | + public: |
| 659 | + /*! \brief The buffer variable. */ |
| 660 | + SparseBuffer buffer; |
| 661 | + /*! \brief The indices location to be loaded. */ |
| 662 | + Array<PrimExpr> indices; |
| 663 | + |
| 664 | + void VisitAttrs(AttrVisitor* v) { |
| 665 | + v->Visit("dtype", &(this->dtype)); |
| 666 | + v->Visit("buffer", &buffer); |
| 667 | + v->Visit("indices", &indices); |
| 668 | + v->Visit("span", &span); |
| 669 | + } |
| 670 | + |
| 671 | + bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const { |
| 672 | + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && |
| 673 | + equal(indices, other->indices); |
| 674 | + } |
| 675 | + |
| 676 | + void SHashReduce(SHashReducer hash_reduce) const { |
| 677 | + hash_reduce(dtype); |
| 678 | + hash_reduce(buffer); |
| 679 | + hash_reduce(indices); |
| 680 | + } |
| 681 | + |
| 682 | + static constexpr const char* _type_key = "tir.SparseBufferLoad"; |
| 683 | + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode); |
| 684 | +}; |
| 685 | + |
| 686 | +/*! |
| 687 | + * \brief Managed reference to SparseBufferLoadNode. |
| 688 | + * \sa SparseBufferLoadNode |
| 689 | + */ |
| 690 | +class SparseBufferLoad : public PrimExpr { |
| 691 | + public: |
| 692 | + TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices, |
| 693 | + Span span = Span()); |
| 694 | + |
| 695 | + TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode); |
| 696 | + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode); |
| 697 | +}; |
| 698 | + |
646 | 699 | /*! |
647 | 700 | * \brief Load value from the result produced by the producer. |
648 | 701 | * |
|
0 commit comments