@@ -101,15 +101,12 @@ class SparseBlockCtx : public SparseCtx {
101101 struct Scope {
102102 explicit Scope (SparseBlock sp_block) : sp_block(std::move(sp_block)) {
103103 for (const SpIterVar& sp_iter_var : this ->sp_block ->sp_iter_vars ) {
104- axis2sp_iter.Set (sp_iter_var->axis , sp_iter_var);
105104 sp_iter_var_map.Set (sp_iter_var->var , sp_iter_var);
106105 }
107106 }
108107
109108 /* ! \brief The sparse block */
110109 SparseBlock sp_block;
111- /* ! \brief A mapping from axes to the sparse iterators that go over them */
112- Map<Axis, SpIterVar> axis2sp_iter;
113110 /* ! \brief A mapping from the internal variables of sparse iterators to the iterators */
114111 Map<Var, SpIterVar> sp_iter_var_map;
115112 /* ! \brief The stored offsets of the axis in the sparse block */
@@ -125,14 +122,40 @@ class SparseBlockCtx : public SparseCtx {
125122 stack_.emplace_back (GetRef<SparseBlock>(sp_block));
126123 /* Compute offsets and coordinates */
127124 size_t n_iters = sp_block->sp_iter_vars .size ();
128- for (size_t i = 0 ; i < n_iters; ++i ) {
125+ for (size_t i = 0 ; i < n_iters;) {
129126 SpIterVar sp_iter_var = sp_block->sp_iter_vars [i];
130127 Axis axis = sp_iter_var->axis ;
131128
132- PrimExpr offset = AggregateOffset (this , axis, sp_iter_var->var , ana_);
133- SetOffset (axis, offset);
134- PrimExpr coordinate = axis->Decompress (this , offset, sp_iter_var->var );
135- SetCoordinate (axis, coordinate);
129+ PrimExpr offset, index;
130+ if (auto fused_axis = axis.as <FusedAxisNode>()) {
131+ auto group = fused_axis->group ;
132+ offset = sp_block->sp_iter_vars [i + group.size () - 1 ]->var ;
133+ for (int j = group.size () - 1 ; j >= 0 ; --j) {
134+ Axis orig = group[j];
135+ SetOffset (orig, offset);
136+ if (j > 0 ) {
137+ // TODO(zihao): support more than sv axis.
138+ offset = lower_bound (Downcast<SparseVariableAxis>(orig)->indptr ->data , offset,
139+ Integer (0 ), orig->GetNNZ ());
140+ }
141+ }
142+ for (size_t j = 0 ; j < group.size (); ++j) {
143+ Axis orig = group[j];
144+ offset = GetOffset (orig);
145+ PrimExpr lb = std::get<0 >(orig->GetOffsetExtent (this ));
146+ index = offset - lb;
147+ PrimExpr coordinate = orig->Decompress (this , offset, index);
148+ SetCoordinate (orig, coordinate);
149+ i++;
150+ }
151+ } else {
152+ offset = AggregateOffset (this , axis, sp_iter_var->var , ana_);
153+ index = sp_iter_var->var ;
154+ PrimExpr coordinate = axis->Decompress (this , offset, index);
155+ SetOffset (axis, offset);
156+ SetCoordinate (axis, coordinate);
157+ i++;
158+ }
136159 }
137160 }
138161
@@ -164,7 +187,7 @@ class SparseBlockCtx : public SparseCtx {
164187 */
165188 PrimExpr GetOffset (Axis axis) const {
166189 Optional<PrimExpr> try_offset = top ()->cached_offsets .Get (axis);
167- CHECK (try_offset.defined ()) << " The offset of axis not defined yet." ;
190+ CHECK (try_offset.defined ()) << " The offset of axis " << axis-> name << " not defined yet." ;
168191 PrimExpr offset = try_offset.value ();
169192 return std::move (offset);
170193 }
@@ -202,7 +225,7 @@ class SparseBlockCtx : public SparseCtx {
202225 }
203226
204227 Optional<Axis> MatchAxis (SparseCtx* buf_ctx, Axis axis) {
205- if (!top ()->axis2sp_iter .Get (axis).defined ()) {
228+ if (!top ()->cached_offsets .Get (axis).defined ()) {
206229 return NullOpt;
207230 } else {
208231 Axis axis_ = axis;
@@ -233,7 +256,11 @@ class SparseBlockCtx : public SparseCtx {
233256 if (!try_sp_iter_var.defined ()) {
234257 return false ;
235258 }
236- return try_sp_iter_var.value ()->axis == matched_axis.value ();
259+ Axis axis = try_sp_iter_var.value ()->axis ;
260+ if (auto fused_axis = axis.as <FusedAxisNode>()) {
261+ axis = fused_axis->group [fused_axis->index ];
262+ }
263+ return axis == matched_axis.value ();
237264 }
238265
239266 private:
@@ -403,7 +430,11 @@ class IndexTransformer : public StmtExprMutator {
403430 auto try_sp_iter = sp_blk_ctx_.GetSparseIterVar (var);
404431 if (try_sp_iter.defined ()) {
405432 SpIterVar sp_iter = try_sp_iter.value ();
406- return sp_blk_ctx_.GetCoordinate (sp_iter->axis );
433+ Axis axis = sp_iter->axis ;
434+ if (auto fused_axis = axis.as <FusedAxisNode>()) {
435+ axis = fused_axis->group [fused_axis->index ];
436+ }
437+ return sp_blk_ctx_.GetCoordinate (axis);
407438 } else {
408439 return GetRef<PrimExpr>(var);
409440 }
0 commit comments