Skip to content

Commit 8e70d89

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5)
* Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()`
1 parent ef197bc commit 8e70d89

File tree

7 files changed

+202
-26
lines changed

7 files changed

+202
-26
lines changed

include/tvm/tir/expr.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <tvm/runtime/container/string.h>
3535
#include <tvm/runtime/data_type.h>
3636
#include <tvm/tir/buffer.h>
37+
#include <tvm/tir/sparse.h>
3738
#include <tvm/tir/var.h>
3839

3940
#include <algorithm>
@@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr {
643644
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
644645
};
645646

647+
/*!
648+
* \brief Load value from the high dimension sparse buffer.
649+
*
650+
* \code
651+
*
652+
* value = buffer[i, j];
653+
*
654+
* \endcode
655+
* \sa SparseBufferStore
656+
*/
657+
class SparseBufferLoadNode : public PrimExprNode {
658+
public:
659+
/*! \brief The buffer variable. */
660+
SparseBuffer buffer;
661+
/*! \brief The indices location to be loaded. */
662+
Array<PrimExpr> indices;
663+
664+
void VisitAttrs(AttrVisitor* v) {
665+
v->Visit("dtype", &(this->dtype));
666+
v->Visit("buffer", &buffer);
667+
v->Visit("indices", &indices);
668+
v->Visit("span", &span);
669+
}
670+
671+
bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const {
672+
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
673+
equal(indices, other->indices);
674+
}
675+
676+
void SHashReduce(SHashReducer hash_reduce) const {
677+
hash_reduce(dtype);
678+
hash_reduce(buffer);
679+
hash_reduce(indices);
680+
}
681+
682+
static constexpr const char* _type_key = "tir.SparseBufferLoad";
683+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode);
684+
};
685+
686+
/*!
687+
* \brief Managed reference to SparseBufferLoadNode.
688+
* \sa SparseBufferLoadNode
689+
*/
690+
class SparseBufferLoad : public PrimExpr {
691+
public:
692+
TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices,
693+
Span span = Span());
694+
695+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode);
696+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode);
697+
};
698+
646699
/*!
647700
* \brief Load value from the result produced by the producer.
648701
*

include/tvm/tir/sparse.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
namespace tvm {
3434
namespace tir {
35-
namespace sparse {
3635

3736
/*!
3837
* \brief Base type for axis in sparse formats.
@@ -308,28 +307,36 @@ class SparseBufferNode : public Object {
308307
AxisTree tree;
309308
/* Axes */
310309
Array<Axis> axes;
311-
/* Number of dimensions */
312-
int ndim;
313310
/* Buffer corresponding to flattened value */
314311
Buffer data;
312+
/* Buffer Name */
313+
String name;
314+
/* Data type */
315+
runtime::DataType dtype;
316+
317+
inline int ndim() const {
318+
return static_cast<int>(axes.size());
319+
}
315320

316321
void VisitAttrs(AttrVisitor* v) {
317322
v->Visit("name", &tree);
318323
v->Visit("length", &axes);
319-
v->Visit("indptr", &ndim);
320324
v->Visit("num_cols", &data);
325+
v->Visit("name", &name);
326+
v->Visit("dtype", &dtype);
321327
}
322328

323329
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
324-
return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
325-
equal(data, other->data);
330+
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
331+
equal(name, other->name) && equal(dtype, other->dtype);
326332
}
327333

328334
void SHashReduce(SHashReducer hash_reduce) const {
329335
hash_reduce(tree);
330336
hash_reduce(axes);
331-
hash_reduce(ndim);
332337
hash_reduce(data);
338+
hash_reduce(name);
339+
hash_reduce(dtype);
333340
}
334341

335342
static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
@@ -342,12 +349,12 @@ class SparseBufferNode : public Object {
342349
*/
343350
class SparseBuffer : public ObjectRef {
344351
public:
345-
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim, Buffer data);
352+
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
353+
DataType dtype);
346354

347355
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
348356
};
349357

350-
} // namespace sparse
351358
} // namespace tir
352359
} // namespace tvm
353360

include/tvm/tir/stmt.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,60 @@ class BufferStore : public Stmt {
327327
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
328328
};
329329

330+
/*!
331+
* \brief Store value to the high dimension sparse buffer.
332+
*
333+
* \code
334+
*
335+
* buffer[i, j] = value;
336+
*
337+
* \endcode
338+
* \sa SparseBufferLoad
339+
*/
340+
class SparseBufferStoreNode : public StmtNode {
341+
public:
342+
/*! \brief The buffer variable. */
343+
SparseBuffer buffer;
344+
/*! \brief The value to be stored. */
345+
PrimExpr value;
346+
/*! \brief The indices location to be stored. */
347+
Array<PrimExpr> indices;
348+
349+
void VisitAttrs(AttrVisitor* v) {
350+
v->Visit("buffer", &buffer);
351+
v->Visit("value", &value);
352+
v->Visit("indices", &indices);
353+
v->Visit("span", &span);
354+
}
355+
356+
bool SEqualReduce(const SparseBufferStoreNode* other, SEqualReducer equal) const {
357+
return equal(buffer, other->buffer) && equal(value, other->value) &&
358+
equal(indices, other->indices);
359+
}
360+
361+
void SHashReduce(SHashReducer hash_reduce) const {
362+
hash_reduce(buffer);
363+
hash_reduce(value);
364+
hash_reduce(indices);
365+
}
366+
367+
static constexpr const char* _type_key = "tir.SparseBufferStore";
368+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferStoreNode, StmtNode);
369+
};
370+
371+
/*!
372+
* \brief Managed reference to SparseBufferStoreNode.
373+
* \sa SparseBufferStoreNode
374+
*/
375+
class SparseBufferStore : public Stmt {
376+
public:
377+
TVM_DLL explicit SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
378+
Span span = Span());
379+
380+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferStore, Stmt, SparseBufferStoreNode);
381+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferStoreNode);
382+
};
383+
330384
/*!
331385
* \brief Annotate the region where the buffer need to
332386
* be read and write in the body.

python/tvm/tir/sparse.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,19 +177,23 @@ class SparseBuffer:
177177
axes : List[Axis]
178178
The axes of the sparse buffer
179179
180-
ndim : int
181-
The number of dimensions of the sparse buffer
182-
183180
data : Buffer
184181
The data of the sparse buffer
182+
183+
name : str
184+
The name of the sparse buffer
185+
186+
dtype : Optional[str]
187+
The data type of the sparse buffer
185188
"""
186189

187190
tree: AxisTree
188191
axes: List[Axis]
189-
ndim: int
190192
data: Buffer
193+
name: str
191194

192-
def __init__(self, tree, axes, ndim, data):
195+
def __init__(self, tree, axes, data, name, dtype=None):
196+
dtype = "float32" if dtype is None else dtype
193197
self.__init_handle_by_constructor__(
194-
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
198+
_ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore
195199
)

src/tir/ir/expr.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
10841084
p->stream << "]";
10851085
});
10861086

1087+
// SparseBufferLoad
1088+
SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
1089+
ObjectPtr<SparseBufferLoadNode> node = make_object<SparseBufferLoadNode>();
1090+
node->dtype = buffer->dtype;
1091+
node->buffer = std::move(buffer);
1092+
node->indices = std::move(indices);
1093+
node->span = std::move(span);
1094+
data_ = std::move(node);
1095+
}
1096+
1097+
TVM_REGISTER_GLOBAL("tir.SparseBufferLoad")
1098+
.set_body_typed([](SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
1099+
return SparseBufferLoad(buffer, indices, span);
1100+
});
1101+
1102+
TVM_REGISTER_NODE_TYPE(SparseBufferLoadNode);
1103+
1104+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
1105+
.set_dispatch<SparseBufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
1106+
auto* op = static_cast<const SparseBufferLoadNode*>(node.get());
1107+
p->stream << op->buffer->name << "[";
1108+
for (size_t i = 0; i < op->indices.size(); ++i) {
1109+
p->Print(op->indices[i]);
1110+
if (i < op->indices.size() - 1) {
1111+
p->stream << ", ";
1112+
}
1113+
}
1114+
p->stream << "]";
1115+
});
1116+
10871117
// ProducerLoad
10881118
ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) {
10891119
ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();

src/tir/ir/sparse.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
namespace tvm {
2929
namespace tir {
3030

31-
namespace sparse {
32-
33-
3431
// DenseFixedAxis
3532
DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
3633
ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
@@ -148,25 +145,23 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
148145
});
149146

150147
// SparseBuffer
151-
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim,
152-
Buffer data) {
148+
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
149+
DataType dtype) {
153150
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
154151
node->tree = std::move(tree);
155152
node->axes = std::move(axes);
156-
node->ndim = ndim;
157153
node->data = std::move(data);
154+
node->name = std::move(name);
155+
node->dtype = dtype;
158156
data_ = std::move(node);
159157
}
160158

161159
TVM_REGISTER_NODE_TYPE(SparseBufferNode);
162160

163161
TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
164-
.set_body_typed([](AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
165-
// Todo(@ruihang): to be revised later
166-
return SparseBuffer(root, axes, ndim, data);
162+
.set_body_typed([](AxisTree tree, Array<Axis> axes, Buffer data, String name, DataType dtype) {
163+
return SparseBuffer(tree, axes, data, name, dtype);
167164
});
168165

169-
} // namespace sparse
170-
171166
} // namespace tir
172167
} // namespace tvm

src/tir/ir/stmt.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,39 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
618618
p->stream << '\n';
619619
});
620620

621+
// SparseBufferStore
622+
SparseBufferStore::SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
623+
Span span) {
624+
ObjectPtr<SparseBufferStoreNode> node = make_object<SparseBufferStoreNode>();
625+
node->buffer = std::move(buffer);
626+
node->value = std::move(value);
627+
node->indices = std::move(indices);
628+
node->span = std::move(span);
629+
data_ = std::move(node);
630+
}
631+
632+
TVM_REGISTER_GLOBAL("tir.SparseBufferStore")
633+
.set_body_typed([](SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
634+
return SparseBufferStore(buffer, value, indices, span);
635+
});
636+
637+
TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);
638+
639+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
640+
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
641+
auto* op = static_cast<const BufferStoreNode*>(node.get());
642+
p->PrintIndent();
643+
p->stream << op->buffer->name << "[";
644+
for (size_t i = 0; i < op->indices.size(); ++i) {
645+
p->Print(op->indices[i]);
646+
if (i < op->indices.size() - 1) p->stream << ", ";
647+
}
648+
p->stream << "]";
649+
p->stream << " = ";
650+
p->Print(op->value);
651+
p->stream << '\n';
652+
});
653+
621654
// BufferRealize
622655
BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
623656
Span span) {

0 commit comments

Comments
 (0)