Skip to content

Commit 16b6e8f

Browse files
authored
Support indices lowering for attach and fuse. (#43)
* upd * upd * upd
1 parent 213ec15 commit 16b6e8f

File tree

5 files changed

+341
-226
lines changed

5 files changed

+341
-226
lines changed

include/tvm/tir/sparse.h

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ enum class AxisKind : int {
4242

4343
class Axis;
4444

45+
/*! \brief Common interface for both SparseBlockCtx and SparseBufferAccessCtx. */
46+
class SparseCtx {
47+
public:
48+
virtual Optional<Axis> GetPrevAxis(Axis axis) const = 0;
49+
virtual PrimExpr GetCoordinate(Axis axis) const = 0;
50+
virtual PrimExpr GetOffset(Axis axis) const = 0;
51+
virtual void SetCoordinate(Axis axis, PrimExpr coordinate) = 0;
52+
virtual void SetOffset(Axis axis, PrimExpr index) = 0;
53+
};
54+
4555
/*!
4656
* \brief Base type for axis in sparse formats.
4757
*/
@@ -71,10 +81,14 @@ class AxisNode : public Object {
7181
PrimExpr GetLength() const { return length; }
7282
DataType GetIndexType() const { return length->dtype; }
7383
virtual Optional<Axis> GetParentAxis() const = 0;
74-
Axis GetRootAxis() const;
7584

7685
virtual AxisKind kind() const = 0;
77-
virtual PrimExpr nnz() const = 0;
86+
virtual PrimExpr GetNNZ() const = 0;
87+
88+
virtual PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const = 0;
89+
virtual PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const = 0;
90+
virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const = 0;
91+
std::tuple<PrimExpr, PrimExpr> GetOffsetExtent(SparseCtx* ctx) const;
7892

7993
static constexpr const char* _type_key = "tir.sparse.Axis";
8094
static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -134,10 +148,16 @@ class DenseFixedAxisNode : public DenseAxisNode {
134148
public:
135149
AxisKind kind() const final { return AxisKind::kDenseFixed; }
136150

137-
PrimExpr nnz() const final { return length; }
151+
PrimExpr GetNNZ() const final { return length; }
138152

139153
Optional<Axis> GetParentAxis() const final { return NullOpt; }
140154

155+
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
156+
157+
PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;
158+
159+
PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
160+
141161
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
142162
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
143163
};
@@ -221,7 +241,7 @@ class FusedAxisNode : public DenseFixedAxisNode {
221241
};
222242

223243
/*!
224-
* \brief Managed refenrence to FusedAxisNode.
244+
* \brief Managed reference to FusedAxisNode.
225245
* \sa FusedAxisNode
226246
*/
227247
class FusedAxis : public DenseFixedAxis {
@@ -257,10 +277,16 @@ class DenseVariableAxisNode : public DenseAxisNode {
257277

258278
AxisKind kind() const final { return AxisKind::kDenseVariable; }
259279

260-
PrimExpr nnz() const final { return nnz_; }
280+
PrimExpr GetNNZ() const final { return nnz_; }
261281

262282
Optional<Axis> GetParentAxis() const final { return parent_; }
263283

284+
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
285+
286+
PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;
287+
288+
PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
289+
264290
static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
265291
TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
266292
};
@@ -287,6 +313,8 @@ class AttachedAxisNode : public DenseVariableAxisNode {
287313

288314
Axis GetOriginalAxis() const { return orig_; }
289315

316+
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
317+
290318
static constexpr const char* _type_key = "tir.sparse.AttachedAxis";
291319
TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode);
292320
};
@@ -307,9 +335,9 @@ class AttachedAxis : public DenseVariableAxis {
307335
class SparseFixedAxisNode : public SparseAxisNode {
308336
public:
309337
Buffer indices;
338+
Axis parent_;
310339
/* fixed number of non-zero columns of current sparse axis. */
311340
PrimExpr nnz_cols;
312-
Axis parent_;
313341

314342
void VisitAttrs(AttrVisitor* v) {
315343
SparseAxisNode::VisitAttrs(v);
@@ -328,12 +356,18 @@ class SparseFixedAxisNode : public SparseAxisNode {
328356
hash_reduce(nnz_cols);
329357
}
330358

331-
PrimExpr nnz() const { return indices->shape[0]; }
359+
PrimExpr GetNNZ() const { return indices->shape[0]; }
332360

333361
AxisKind kind() const final { return AxisKind::kSparseFixed; }
334362

335363
Optional<Axis> GetParentAxis() const final { return parent_; }
336364

365+
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
366+
367+
PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;
368+
369+
PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
370+
337371
static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
338372
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
339373
};
@@ -376,12 +410,18 @@ class SparseVariableAxisNode : public SparseAxisNode {
376410
hash_reduce(indices);
377411
}
378412

379-
PrimExpr nnz() const { return indices->shape[0]; }
413+
PrimExpr GetNNZ() const { return indices->shape[0]; }
380414

381415
AxisKind kind() const final { return AxisKind::kSparseVariable; }
382416

383417
Optional<Axis> GetParentAxis() const final { return parent_; }
384418

419+
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
420+
421+
PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;
422+
423+
PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
424+
385425
static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
386426
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
387427
};

src/printer/tvmscript_printer.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,12 +1406,12 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
14061406
if (const auto* attached_axis = obj.as<AttachedAxisNode>()) {
14071407
ICHECK_EQ(params.size(), 1);
14081408
doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name
1409-
<< ", " << Print(attached_axis->nnz()) << ", " << Print(params[0]) << ", "
1409+
<< ", " << Print(attached_axis->GetNNZ()) << ", " << Print(params[0]) << ", "
14101410
<< PrintDType(attached_axis->indptr->dtype) << ")";
14111411
} else {
14121412
ICHECK_EQ(params.size(), 1);
14131413
doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length)
1414-
<< ", " << Print(dv_axis->nnz()) << "), " << Print(params[0]) << ", "
1414+
<< ", " << Print(dv_axis->GetNNZ()) << "), " << Print(params[0]) << ", "
14151415
<< PrintDType(dv_axis->indptr->dtype) << ")";
14161416
}
14171417
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
@@ -1422,7 +1422,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
14221422
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
14231423
ICHECK_EQ(params.size(), 2);
14241424
doc << "sparse_variable(" << sv_axis->parent_->name << ", (" << Print(sv_axis->length) << ", "
1425-
<< Print(sv_axis->nnz()) << "), (" << Print(params[0]) << ", " << Print(params[1])
1425+
<< Print(sv_axis->GetNNZ()) << "), (" << Print(params[0]) << ", " << Print(params[1])
14261426
<< "), " << PrintDType(sv_axis->indptr->dtype) << ")";
14271427
} else {
14281428
ICHECK(false) << "Cannot reach here";

src/tir/ir/sparse.cc

Lines changed: 111 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <tvm/arith/analyzer.h>
2525
#include <tvm/runtime/registry.h>
2626
#include <tvm/tir/buffer.h>
27+
#include <tvm/tir/op.h>
2728
#include <tvm/tir/sparse.h>
2829

2930
namespace tvm {
@@ -43,19 +44,27 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)
4344
return DLDataType2String(axis->GetIndexType());
4445
});
4546

46-
TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->nnz(); });
47+
TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->GetNNZ(); });
4748

4849
/******** AxisNode ********/
4950

50-
/*! \brief Implementation of get root axis function. */
51-
Axis AxisNode::GetRootAxis() const {
52-
Optional<Axis> parent = GetParentAxis();
53-
if (parent.defined()) {
54-
return parent.value()->GetRootAxis();
51+
std::tuple<PrimExpr, PrimExpr> AxisNode::GetOffsetExtent(SparseCtx* ctx) const {
52+
auto prev = ctx->GetPrevAxis(GetRef<Axis>(this));
53+
if (prev.defined()) {
54+
Axis prev_axis = prev.value();
55+
PrimExpr lb = Aggregate(ctx, 0);
56+
PrimExpr orig_prev_coordinate = ctx->GetCoordinate(prev_axis),
57+
orig_prev_offset = ctx->GetOffset(prev_axis);
58+
ctx->SetCoordinate(prev_axis, orig_prev_coordinate + 1);
59+
ctx->SetOffset(prev_axis, orig_prev_offset + 1);
60+
PrimExpr ub = Aggregate(ctx, 0);
61+
ctx->SetCoordinate(prev_axis, orig_prev_coordinate);
62+
ctx->SetOffset(prev_axis, orig_prev_offset);
63+
return {lb, ub};
5564
} else {
56-
return GetRef<Axis>(this);
65+
return {Integer(0), GetNNZ()};
5766
}
58-
}
67+
};
5968

6069
/******** DenseFixedAxis ********/
6170

@@ -67,6 +76,23 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
6776
data_ = std::move(node);
6877
}
6978

79+
PrimExpr DenseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
80+
auto try_prev = ctx->GetPrevAxis(GetRef<Axis>(this));
81+
if (try_prev.defined()) {
82+
Axis prev_axis = try_prev.value();
83+
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
84+
return prev_offset * length + std::move(index);
85+
} else {
86+
return index;
87+
}
88+
}
89+
90+
PrimExpr DenseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
91+
return coordinate;
92+
}
93+
94+
PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }
95+
7096
TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);
7197

7298
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
@@ -112,11 +138,11 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
112138

113139
ObjectPtr<FusedAxisNode> node = make_object<FusedAxisNode>();
114140
std::string fused_name = group[0]->name;
115-
for (int i = 1; i < group.size(); ++i) {
141+
for (size_t i = 1; i < group.size(); ++i) {
116142
fused_name += group[i]->name;
117143
}
118144
node->name = "fused_" + fused_name + "_" + group[index]->name;
119-
node->length = group[index]->nnz();
145+
node->length = group[index]->GetNNZ();
120146
node->group = std::move(group);
121147
node->index = index;
122148
data_ = std::move(node);
@@ -146,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
146172

147173
/******** DenseVariableAxis ********/
148174

149-
/*! \brief Default constuctor of DenseVariableAxis */
175+
/*! \brief Default constructor of DenseVariableAxis */
150176
DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
151177
Buffer indptr) {
152178
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
@@ -158,6 +184,18 @@ DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length,
158184
data_ = std::move(node);
159185
}
160186

187+
PrimExpr DenseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
188+
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
189+
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
190+
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
191+
}
192+
193+
PrimExpr DenseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
194+
return coordinate;
195+
}
196+
197+
PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }
198+
161199
TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);
162200

163201
TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
@@ -186,6 +224,34 @@ AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Bu
186224
data_ = std::move(node);
187225
}
188226

227+
PrimExpr AttachedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
228+
PrimExpr parent_offset = ctx->GetOffset(parent_);
229+
PrimExpr base_offset = BufferLoad(indptr, {parent_offset});
230+
PrimExpr accum_offset = Integer(0);
231+
PrimExpr length = Integer(0);
232+
Array<Axis> collect_axes;
233+
Array<PrimExpr> collect_coordinates;
234+
Axis axis;
235+
for (axis = GetRef<Axis>(this); axis->kind() == AxisKind::kDenseVariable;
236+
axis = ctx->GetPrevAxis(axis).value()) {
237+
collect_axes.push_back(axis);
238+
collect_coordinates.push_back(ctx->GetCoordinate(axis));
239+
}
240+
ICHECK(axis.get() == parent_.get())
241+
<< "The root of attached axis should be the same as stored parent axis.";
242+
for (int i = collect_axes.size() - 1; i != 0; --i) {
243+
Axis axis = std::move(collect_axes[i]);
244+
auto* ptr = axis.as<DenseVariableAxisNode>();
245+
ICHECK(ptr != nullptr)
246+
<< "Each attached axis except for the root must be a dense variable axis";
247+
PrimExpr coordinate = std::move(collect_coordinates[i]);
248+
accum_offset = accum_offset * length + coordinate;
249+
length =
250+
BufferLoad(ptr->indptr, {parent_offset + 1}) - BufferLoad(ptr->indptr, {parent_offset});
251+
}
252+
return base_offset + accum_offset;
253+
}
254+
189255
TVM_REGISTER_NODE_TYPE(AttachedAxisNode);
190256

191257
TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis")
@@ -215,6 +281,22 @@ SparseFixedAxis::SparseFixedAxis(String name, Axis parent, PrimExpr length, Buff
215281
data_ = std::move(node);
216282
}
217283

284+
PrimExpr SparseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
285+
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
286+
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
287+
return std::move(prev_offset) * nnz_cols + std::move(index);
288+
}
289+
290+
PrimExpr SparseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
291+
PrimExpr lb, ub;
292+
std::tie(lb, ub) = GetOffsetExtent(ctx);
293+
return lower_bound(indices->data, coordinate, lb, ub) - lb;
294+
}
295+
296+
PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
297+
return BufferLoad(indices, {offset});
298+
}
299+
218300
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
219301

220302
TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
@@ -227,8 +309,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
227309
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
228310
.set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
229311
auto* op = static_cast<const SparseFixedAxisNode*>(node.get());
230-
p->stream << "sparse_fixed(" << op->name << ", " << op->GetParentAxis().value()->name << ", "
231-
<< op->length << ", " << op->nnz_cols << ", " << op->indices->name << ")";
312+
p->stream << "sparse_fixed(" << op->name << ", " << op->parent_->name << ", " << op->length
313+
<< ", " << op->nnz_cols << ", " << op->indices->name << ")";
232314
});
233315

234316
/******** SparseVariableAxis ********/
@@ -245,6 +327,22 @@ SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length
245327
data_ = std::move(node);
246328
}
247329

330+
PrimExpr SparseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
331+
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
332+
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
333+
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
334+
}
335+
336+
PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
337+
PrimExpr lb, ub;
338+
std::tie(lb, ub) = GetOffsetExtent(ctx);
339+
return lower_bound(indices->data, coordinate, lb, ub) - lb;
340+
}
341+
342+
PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
343+
return BufferLoad(indices, {offset});
344+
}
345+
248346
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
249347

250348
TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")

0 commit comments

Comments
 (0)