2424#include < tvm/arith/analyzer.h>
2525#include < tvm/runtime/registry.h>
2626#include < tvm/tir/buffer.h>
27+ #include < tvm/tir/op.h>
2728#include < tvm/tir/sparse.h>
2829
2930namespace tvm {
@@ -43,19 +44,27 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis)
4344 return DLDataType2String (axis->GetIndexType ());
4445});
4546
46- TVM_REGISTER_GLOBAL (" tir.sparse.GetNNZ" ).set_body_typed([](Axis axis) { return axis->nnz (); });
47+ TVM_REGISTER_GLOBAL (" tir.sparse.GetNNZ" ).set_body_typed([](Axis axis) { return axis->GetNNZ (); });
4748
4849/* ******* AxisNode ********/
4950
50- /* ! \brief Implementation of get root axis function. */
51- Axis AxisNode::GetRootAxis () const {
52- Optional<Axis> parent = GetParentAxis ();
53- if (parent.defined ()) {
54- return parent.value ()->GetRootAxis ();
51+ std::tuple<PrimExpr, PrimExpr> AxisNode::GetOffsetExtent (SparseCtx* ctx) const {
52+ auto prev = ctx->GetPrevAxis (GetRef<Axis>(this ));
53+ if (prev.defined ()) {
54+ Axis prev_axis = prev.value ();
55+ PrimExpr lb = Aggregate (ctx, 0 );
56+ PrimExpr orig_prev_coordinate = ctx->GetCoordinate (prev_axis),
57+ orig_prev_offset = ctx->GetOffset (prev_axis);
58+ ctx->SetCoordinate (prev_axis, orig_prev_coordinate + 1 );
59+ ctx->SetOffset (prev_axis, orig_prev_offset + 1 );
60+ PrimExpr ub = Aggregate (ctx, 0 );
61+ ctx->SetCoordinate (prev_axis, orig_prev_coordinate);
62+ ctx->SetOffset (prev_axis, orig_prev_offset);
63+ return {lb, ub};
5564 } else {
56- return GetRef<Axis>( this ) ;
65+ return { Integer ( 0 ), GetNNZ ()} ;
5766 }
58- }
67+ };
5968
6069/* ******* DenseFixedAxis ********/
6170
@@ -67,6 +76,23 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
6776 data_ = std::move (node);
6877}
6978
79+ PrimExpr DenseFixedAxisNode::Aggregate (SparseCtx* ctx, PrimExpr index) const {
80+ auto try_prev = ctx->GetPrevAxis (GetRef<Axis>(this ));
81+ if (try_prev.defined ()) {
82+ Axis prev_axis = try_prev.value ();
83+ PrimExpr prev_offset = ctx->GetOffset (prev_axis);
84+ return prev_offset * length + std::move (index);
85+ } else {
86+ return index;
87+ }
88+ }
89+
90+ PrimExpr DenseFixedAxisNode::Compress (SparseCtx* ctx, PrimExpr coordinate) const {
91+ return coordinate;
92+ }
93+
94+ PrimExpr DenseFixedAxisNode::Decompress (SparseCtx* ctx, PrimExpr offset) const { return offset; }
95+
7096TVM_REGISTER_NODE_TYPE (DenseFixedAxisNode);
7197
7298TVM_REGISTER_GLOBAL (" tir.sparse.DenseFixedAxis" ).set_body_typed([](String name, PrimExpr length) {
@@ -112,11 +138,11 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
112138
113139 ObjectPtr<FusedAxisNode> node = make_object<FusedAxisNode>();
114140 std::string fused_name = group[0 ]->name ;
115- for (int i = 1 ; i < group.size (); ++i) {
141+ for (size_t i = 1 ; i < group.size (); ++i) {
116142 fused_name += group[i]->name ;
117143 }
118144 node->name = " fused_" + fused_name + " _" + group[index]->name ;
119- node->length = group[index]->nnz ();
145+ node->length = group[index]->GetNNZ ();
120146 node->group = std::move (group);
121147 node->index = index;
122148 data_ = std::move (node);
@@ -146,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
146172
147173/* ******* DenseVariableAxis ********/
148174
149- /* ! \brief Default constuctor of DenseVariableAxis */
175+ /* ! \brief Default constructor of DenseVariableAxis */
150176DenseVariableAxis::DenseVariableAxis (String name, Axis parent, PrimExpr length, PrimExpr nnz,
151177 Buffer indptr) {
152178 ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
@@ -158,6 +184,18 @@ DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length,
158184 data_ = std::move (node);
159185}
160186
187+ PrimExpr DenseVariableAxisNode::Aggregate (SparseCtx* ctx, PrimExpr index) const {
188+ Axis prev_axis = ctx->GetPrevAxis (GetRef<Axis>(this )).value ();
189+ PrimExpr prev_offset = ctx->GetOffset (prev_axis);
190+ return add (BufferLoad (indptr, {std::move (prev_offset)}), std::move (index));
191+ }
192+
193+ PrimExpr DenseVariableAxisNode::Compress (SparseCtx* ctx, PrimExpr coordinate) const {
194+ return coordinate;
195+ }
196+
197+ PrimExpr DenseVariableAxisNode::Decompress (SparseCtx* ctx, PrimExpr offset) const { return offset; }
198+
161199TVM_REGISTER_NODE_TYPE (DenseVariableAxisNode);
162200
163201TVM_REGISTER_GLOBAL (" tir.sparse.DenseVariableAxis" )
@@ -186,6 +224,34 @@ AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Bu
186224 data_ = std::move (node);
187225}
188226
227+ PrimExpr AttachedAxisNode::Aggregate (SparseCtx* ctx, PrimExpr index) const {
228+ PrimExpr parent_offset = ctx->GetOffset (parent_);
229+ PrimExpr base_offset = BufferLoad (indptr, {parent_offset});
230+ PrimExpr accum_offset = Integer (0 );
231+ PrimExpr length = Integer (0 );
232+ Array<Axis> collect_axes;
233+ Array<PrimExpr> collect_coordinates;
234+ Axis axis;
235+ for (axis = GetRef<Axis>(this ); axis->kind () == AxisKind::kDenseVariable ;
236+ axis = ctx->GetPrevAxis (axis).value ()) {
237+ collect_axes.push_back (axis);
238+ collect_coordinates.push_back (ctx->GetCoordinate (axis));
239+ }
240+ ICHECK (axis.get () == parent_.get ())
241+ << " The root of attached axis should be the same as stored parent axis." ;
242+ for (int i = collect_axes.size () - 1 ; i != 0 ; --i) {
243+ Axis axis = std::move (collect_axes[i]);
244+ auto * ptr = axis.as <DenseVariableAxisNode>();
245+ ICHECK (ptr != nullptr )
246+ << " Each attached axis except for the root must be a dense variable axis" ;
247+ PrimExpr coordinate = std::move (collect_coordinates[i]);
248+ accum_offset = accum_offset * length + coordinate;
249+ length =
250+ BufferLoad (ptr->indptr , {parent_offset + 1 }) - BufferLoad (ptr->indptr , {parent_offset});
251+ }
252+ return base_offset + accum_offset;
253+ }
254+
189255TVM_REGISTER_NODE_TYPE (AttachedAxisNode);
190256
191257TVM_REGISTER_GLOBAL (" tir.sparse.AttachedAxis" )
@@ -215,6 +281,22 @@ SparseFixedAxis::SparseFixedAxis(String name, Axis parent, PrimExpr length, Buff
215281 data_ = std::move (node);
216282}
217283
284+ PrimExpr SparseFixedAxisNode::Aggregate (SparseCtx* ctx, PrimExpr index) const {
285+ Axis prev_axis = ctx->GetPrevAxis (GetRef<Axis>(this )).value ();
286+ PrimExpr prev_offset = ctx->GetOffset (prev_axis);
287+ return std::move (prev_offset) * nnz_cols + std::move (index);
288+ }
289+
290+ PrimExpr SparseFixedAxisNode::Compress (SparseCtx* ctx, PrimExpr coordinate) const {
291+ PrimExpr lb, ub;
292+ std::tie (lb, ub) = GetOffsetExtent (ctx);
293+ return lower_bound (indices->data , coordinate, lb, ub) - lb;
294+ }
295+
296+ PrimExpr SparseFixedAxisNode::Decompress (SparseCtx* ctx, PrimExpr offset) const {
297+ return BufferLoad (indices, {offset});
298+ }
299+
218300TVM_REGISTER_NODE_TYPE (SparseFixedAxisNode);
219301
220302TVM_REGISTER_GLOBAL (" tir.sparse.SparseFixedAxis" )
@@ -227,8 +309,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
227309TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
228310 .set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
229311 auto * op = static_cast <const SparseFixedAxisNode*>(node.get ());
230- p->stream << " sparse_fixed(" << op->name << " , " << op->GetParentAxis (). value () ->name << " , "
231- << op-> length << " , " << op->nnz_cols << " , " << op->indices ->name << " )" ;
312+ p->stream << " sparse_fixed(" << op->name << " , " << op->parent_ ->name << " , " << op-> length
313+ << " , " << op->nnz_cols << " , " << op->indices ->name << " )" ;
232314 });
233315
234316/* ******* SparseVariableAxis ********/
@@ -245,6 +327,22 @@ SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length
245327 data_ = std::move (node);
246328}
247329
330+ PrimExpr SparseVariableAxisNode::Aggregate (SparseCtx* ctx, PrimExpr index) const {
331+ Axis prev_axis = ctx->GetPrevAxis (GetRef<Axis>(this )).value ();
332+ PrimExpr prev_offset = ctx->GetOffset (prev_axis);
333+ return add (BufferLoad (indptr, {std::move (prev_offset)}), std::move (index));
334+ }
335+
336+ PrimExpr SparseVariableAxisNode::Compress (SparseCtx* ctx, PrimExpr coordinate) const {
337+ PrimExpr lb, ub;
338+ std::tie (lb, ub) = GetOffsetExtent (ctx);
339+ return lower_bound (indices->data , coordinate, lb, ub) - lb;
340+ }
341+
342+ PrimExpr SparseVariableAxisNode::Decompress (SparseCtx* ctx, PrimExpr offset) const {
343+ return BufferLoad (indices, {offset});
344+ }
345+
248346TVM_REGISTER_NODE_TYPE (SparseVariableAxisNode);
249347
250348TVM_REGISTER_GLOBAL (" tir.sparse.SparseVariableAxis" )
0 commit comments