Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ enum class AxisKind : int {

class Axis;

/*! \brief Common interface for both SparseBlockCtx and SparseBufferAccessCtx. */
class SparseCtx {
public:
virtual Optional<Axis> GetPrevAxis(Axis axis) const = 0;
virtual PrimExpr GetCoordinate(Axis axis) const = 0;
virtual PrimExpr GetOffset(Axis axis) const = 0;
virtual void SetCoordinate(Axis axis, PrimExpr coordinate) = 0;
virtual void SetOffset(Axis axis, PrimExpr index) = 0;
};

/*!
* \brief Base type for axis in sparse formats.
*/
Expand Down Expand Up @@ -71,10 +81,14 @@ class AxisNode : public Object {
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }
virtual Optional<Axis> GetParentAxis() const = 0;
Axis GetRootAxis() const;

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;
virtual PrimExpr GetNNZ() const = 0;

virtual PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const = 0;
virtual PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const = 0;
virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const = 0;
std::tuple<PrimExpr, PrimExpr> GetOffsetExtent(SparseCtx* ctx) const;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -134,10 +148,16 @@ class DenseFixedAxisNode : public DenseAxisNode {
public:
AxisKind kind() const final { return AxisKind::kDenseFixed; }

PrimExpr nnz() const final { return length; }
PrimExpr GetNNZ() const final { return length; }

Optional<Axis> GetParentAxis() const final { return NullOpt; }

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -221,7 +241,7 @@ class FusedAxisNode : public DenseFixedAxisNode {
};

/*!
* \brief Managed refenrence to FusedAxisNode.
* \brief Managed reference to FusedAxisNode.
* \sa FusedAxisNode
*/
class FusedAxis : public DenseFixedAxis {
Expand Down Expand Up @@ -257,10 +277,16 @@ class DenseVariableAxisNode : public DenseAxisNode {

AxisKind kind() const final { return AxisKind::kDenseVariable; }

PrimExpr nnz() const final { return nnz_; }
PrimExpr GetNNZ() const final { return nnz_; }

Optional<Axis> GetParentAxis() const final { return parent_; }

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand All @@ -287,6 +313,8 @@ class AttachedAxisNode : public DenseVariableAxisNode {

Axis GetOriginalAxis() const { return orig_; }

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.AttachedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode);
};
Expand All @@ -307,9 +335,9 @@ class AttachedAxis : public DenseVariableAxis {
class SparseFixedAxisNode : public SparseAxisNode {
public:
Buffer indices;
Axis parent_;
/* fixed number of non-zero columns of current sparse axis. */
PrimExpr nnz_cols;
Axis parent_;

void VisitAttrs(AttrVisitor* v) {
SparseAxisNode::VisitAttrs(v);
Expand All @@ -328,12 +356,18 @@ class SparseFixedAxisNode : public SparseAxisNode {
hash_reduce(nnz_cols);
}

PrimExpr nnz() const { return indices->shape[0]; }
PrimExpr GetNNZ() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseFixed; }

Optional<Axis> GetParentAxis() const final { return parent_; }

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand Down Expand Up @@ -376,12 +410,18 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

PrimExpr nnz() const { return indices->shape[0]; }
PrimExpr GetNNZ() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseVariable; }

Optional<Axis> GetParentAxis() const final { return parent_; }

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};
Expand Down
6 changes: 3 additions & 3 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1406,12 +1406,12 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
if (const auto* attached_axis = obj.as<AttachedAxisNode>()) {
ICHECK_EQ(params.size(), 1);
doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name
<< ", " << Print(attached_axis->nnz()) << ", " << Print(params[0]) << ", "
<< ", " << Print(attached_axis->GetNNZ()) << ", " << Print(params[0]) << ", "
<< PrintDType(attached_axis->indptr->dtype) << ")";
} else {
ICHECK_EQ(params.size(), 1);
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length)
<< ", " << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
<< ", " << Print(dv_axis->GetNNZ()) << "), " << Print(params[0]) << ", "
<< PrintDType(dv_axis->indptr->dtype) << ")";
}
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
Expand All @@ -1422,7 +1422,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
ICHECK_EQ(params.size(), 2);
doc << "sparse_variable(" << sv_axis->parent_->name << ", (" << Print(sv_axis->length) << ", "
<< Print(sv_axis->nnz()) << "), (" << Print(params[0]) << ", " << Print(params[1])
<< Print(sv_axis->GetNNZ()) << "), (" << Print(params[0]) << ", " << Print(params[1])
<< "), " << PrintDType(sv_axis->indptr->dtype) << ")";
} else {
ICHECK(false) << "Cannot reach here";
Expand Down
124 changes: 111 additions & 13 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/sparse.h>

namespace tvm {
Expand All @@ -43,19 +44,27 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)
return DLDataType2String(axis->GetIndexType());
});

TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); });
TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->GetNNZ(); });

/******** AxisNode ********/

/*! \brief Implementation of get root axis function. */
Axis AxisNode::GetRootAxis() const {
Optional<Axis> parent = GetParentAxis();
if (parent.defined()) {
return parent.value()->GetRootAxis();
std::tuple<PrimExpr, PrimExpr> AxisNode::GetOffsetExtent(SparseCtx* ctx) const {
auto prev = ctx->GetPrevAxis(GetRef<Axis>(this));
if (prev.defined()) {
Axis prev_axis = prev.value();
PrimExpr lb = Aggregate(ctx, 0);
PrimExpr orig_prev_coordinate = ctx->GetCoordinate(prev_axis),
orig_prev_offset = ctx->GetOffset(prev_axis);
ctx->SetCoordinate(prev_axis, orig_prev_coordinate + 1);
ctx->SetOffset(prev_axis, orig_prev_offset + 1);
PrimExpr ub = Aggregate(ctx, 0);
ctx->SetCoordinate(prev_axis, orig_prev_coordinate);
ctx->SetOffset(prev_axis, orig_prev_offset);
return {lb, ub};
} else {
return GetRef<Axis>(this);
return {Integer(0), GetNNZ()};
}
}
};

/******** DenseFixedAxis ********/

Expand All @@ -67,6 +76,23 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
data_ = std::move(node);
}

PrimExpr DenseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
auto try_prev = ctx->GetPrevAxis(GetRef<Axis>(this));
if (try_prev.defined()) {
Axis prev_axis = try_prev.value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return prev_offset * length + std::move(index);
} else {
return index;
}
}

PrimExpr DenseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
return coordinate;
}

PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }

TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
Expand Down Expand Up @@ -112,11 +138,11 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {

ObjectPtr<FusedAxisNode> node = make_object<FusedAxisNode>();
std::string fused_name = group[0]->name;
for (int i = 1; i < group.size(); ++i) {
for (size_t i = 1; i < group.size(); ++i) {
fused_name += group[i]->name;
}
node->name = "fused_" + fused_name + "_" + group[index]->name;
node->length = group[index]->nnz();
node->length = group[index]->GetNNZ();
node->group = std::move(group);
node->index = index;
data_ = std::move(node);
Expand Down Expand Up @@ -146,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
/*! \brief Default constructor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
Expand All @@ -158,6 +184,18 @@ DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length,
data_ = std::move(node);
}

PrimExpr DenseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
}

PrimExpr DenseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
return coordinate;
}

PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
Expand Down Expand Up @@ -186,6 +224,34 @@ AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Bu
data_ = std::move(node);
}

PrimExpr AttachedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
PrimExpr parent_offset = ctx->GetOffset(parent_);
PrimExpr base_offset = BufferLoad(indptr, {parent_offset});
PrimExpr accum_offset = Integer(0);
PrimExpr length = Integer(0);
Array<Axis> collect_axes;
Array<PrimExpr> collect_coordinates;
Axis axis;
for (axis = GetRef<Axis>(this); axis->kind() == AxisKind::kDenseVariable;
axis = ctx->GetPrevAxis(axis).value()) {
collect_axes.push_back(axis);
collect_coordinates.push_back(ctx->GetCoordinate(axis));
}
ICHECK(axis.get() == parent_.get())
<< "The root of attached axis should be the same as stored parent axis.";
for (int i = collect_axes.size() - 1; i != 0; --i) {
Axis axis = std::move(collect_axes[i]);
auto* ptr = axis.as<DenseVariableAxisNode>();
ICHECK(ptr != nullptr)
<< "Each attached axis except for the root must be a dense variable axis";
PrimExpr coordinate = std::move(collect_coordinates[i]);
accum_offset = accum_offset * length + coordinate;
length =
BufferLoad(ptr->indptr, {parent_offset + 1}) - BufferLoad(ptr->indptr, {parent_offset});
}
return base_offset + accum_offset;
}

TVM_REGISTER_NODE_TYPE(AttachedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis")
Expand Down Expand Up @@ -215,6 +281,22 @@ SparseFixedAxis::SparseFixedAxis(String name, Axis parent, PrimExpr length, Buff
data_ = std::move(node);
}

PrimExpr SparseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return std::move(prev_offset) * nnz_cols + std::move(index);
}

PrimExpr SparseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
PrimExpr lb, ub;
std::tie(lb, ub) = GetOffsetExtent(ctx);
return lower_bound(indices->data, coordinate, lb, ub) - lb;
}

PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
return BufferLoad(indices, {offset});
}

TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
Expand All @@ -227,8 +309,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SparseFixedAxisNode*>(node.get());
p->stream << "sparse_fixed(" << op->name << ", " << op->GetParentAxis().value()->name << ", "
<< op->length << ", " << op->nnz_cols << ", " << op->indices->name << ")";
p->stream << "sparse_fixed(" << op->name << ", " << op->parent_->name << ", " << op->length
<< ", " << op->nnz_cols << ", " << op->indices->name << ")";
});

/******** SparseVariableAxis ********/
Expand All @@ -245,6 +327,22 @@ SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length
data_ = std::move(node);
}

PrimExpr SparseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
}

PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
PrimExpr lb, ub;
std::tie(lb, ub) = GetOffsetExtent(ctx);
return lower_bound(indices->data, coordinate, lb, ub) - lb;
}

PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
return BufferLoad(indices, {offset});
}

TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
Expand Down
Loading