Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
#include <tvm/tir/var.h>

#include <algorithm>
Expand Down Expand Up @@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
};

/*!
* \brief Load value from the high dimension sparse buffer.
*
* \code
*
* value = buffer[i, j];
*
* \endcode
* \sa SparseBufferStore
*/
class SparseBufferLoadNode : public PrimExprNode {
public:
/*! \brief The buffer variable. */
SparseBuffer buffer;
/*! \brief The indices location to be loaded. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("buffer", &buffer);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(buffer);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.SparseBufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode);
};

/*!
* \brief Managed reference to SparseBufferLoadNode.
* \sa SparseBufferLoadNode
*/
class SparseBufferLoad : public PrimExpr {
public:
TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode);
};

/*!
* \brief Load value from the result produced by the producer.
*
Expand Down
25 changes: 16 additions & 9 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

namespace tvm {
namespace tir {
namespace sparse {

/*!
* \brief Base type for axis in sparse formats.
Expand Down Expand Up @@ -308,28 +307,36 @@ class SparseBufferNode : public Object {
AxisTree tree;
/* Axes */
Array<Axis> axes;
/* Number of dimensions */
int ndim;
/* Buffer corresponding to flattened value */
Buffer data;
/* Buffer Name */
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding a function ndim() that returns the length of axes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! Will do.

String name;
/* Data type */
runtime::DataType dtype;

inline int ndim() const {
return static_cast<int>(axes.size());
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
v->Visit("length", &axes);
v->Visit("indptr", &ndim);
v->Visit("num_cols", &data);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
equal(data, other->data);
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name) && equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tree);
hash_reduce(axes);
hash_reduce(ndim);
hash_reduce(data);
hash_reduce(name);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -342,12 +349,12 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim, Buffer data);
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

} // namespace sparse
} // namespace tir
} // namespace tvm

Expand Down
54 changes: 54 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,60 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
* \brief Store value to the high dimension sparse buffer.
*
* \code
*
* buffer[i, j] = value;
*
* \endcode
* \sa SparseBufferLoad
*/
class SparseBufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
SparseBuffer buffer;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
v->Visit("span", &span);
}

bool SEqualReduce(const SparseBufferStoreNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(value, other->value) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.SparseBufferStore";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferStoreNode, StmtNode);
};

/*!
* \brief Managed reference to SparseBufferStoreNode.
* \sa SparseBufferStoreNode
*/
class SparseBufferStore : public Stmt {
public:
TVM_DLL explicit SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferStore, Stmt, SparseBufferStoreNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferStoreNode);
};

/*!
* \brief Annotate the region where the buffer need to
* be read and write in the body.
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,23 @@ class SparseBuffer:
axes : List[Axis]
The axes of the sparse buffer

ndim : int
The number of dimensions of the sparse buffer

data : Buffer
The data of the sparse buffer

name : str
The name of the sparse buffer

dtype : Optional[str]
The data type of the sparse buffer
"""

tree: AxisTree
axes: List[Axis]
ndim: int
data: Buffer
name: str

def __init__(self, tree, axes, ndim, data):
def __init__(self, tree, axes, data, name, dtype=None):
dtype = "float32" if dtype is None else dtype
self.__init_handle_by_constructor__(
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
_ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore
)
30 changes: 30 additions & 0 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "]";
});

// SparseBufferLoad
SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
ObjectPtr<SparseBufferLoadNode> node = make_object<SparseBufferLoadNode>();
node->dtype = buffer->dtype;
node->buffer = std::move(buffer);
node->indices = std::move(indices);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SparseBufferLoad")
.set_body_typed([](SparseBuffer buffer, Array<PrimExpr> indices, Span span) {
return SparseBufferLoad(buffer, indices, span);
});

TVM_REGISTER_NODE_TYPE(SparseBufferLoadNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SparseBufferLoadNode*>(node.get());
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});

// ProducerLoad
ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) {
ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
Expand Down
17 changes: 6 additions & 11 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
namespace tvm {
namespace tir {

namespace sparse {


// DenseFixedAxis
DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
Expand Down Expand Up @@ -148,25 +145,23 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
});

// SparseBuffer
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim,
Buffer data) {
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype) {
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
node->tree = std::move(tree);
node->axes = std::move(axes);
node->ndim = ndim;
node->data = std::move(data);
node->name = std::move(name);
node->dtype = dtype;
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
// Todo(@ruihang): to be revised later
return SparseBuffer(root, axes, ndim, data);
.set_body_typed([](AxisTree tree, Array<Axis> axes, Buffer data, String name, DataType dtype) {
return SparseBuffer(tree, axes, data, name, dtype);
});

} // namespace sparse

} // namespace tir
} // namespace tvm
33 changes: 33 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,39 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << '\n';
});

// SparseBufferStore
SparseBufferStore::SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices,
Span span) {
ObjectPtr<SparseBufferStoreNode> node = make_object<SparseBufferStoreNode>();
node->buffer = std::move(buffer);
node->value = std::move(value);
node->indices = std::move(indices);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SparseBufferStore")
.set_body_typed([](SparseBuffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
return SparseBufferStore(buffer, value, indices, span);
});

TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) p->stream << ", ";
}
p->stream << "]";
p->stream << " = ";
p->Print(op->value);
p->stream << '\n';
});

// BufferRealize
BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
Span span) {
Expand Down