Skip to content

Commit 2b542f5

Browse files
committed
Fix sparse reorder after refactor (#47)
1 parent 3ecc489 commit 2b542f5

File tree

3 files changed

+89
-252
lines changed

3 files changed

+89
-252
lines changed

src/tir/schedule/primitive/sparse_loop_transformation.cc

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,58 @@ void CheckValidInputIterators(const ScheduleState self, const Array<SpIterVar>&
9494
}
9595
}
9696

97+
/*!
98+
* \brief Check whether the sparse reorder would break dependency between iterators.
99+
* \param new_order The new iterator order to be checked.
100+
* \throw ScheduleError If the sparse reorder breaks dependency.
101+
*/
102+
void CheckDependency(const ScheduleState self, const Array<SpIterVar>& new_order) {
103+
class DependencyError : public ScheduleError {
104+
public:
105+
explicit DependencyError(IRModule mod, SpIterVar iter, Array<SpIterVar> new_order):
106+
mod_(std::move(mod)), iter_(std::move(iter)), new_order_(std::move(new_order)) {}
107+
108+
String FastErrorString() const final {
109+
return "ScheduleError: the sparse reorder breaks dependency between axes.";
110+
}
111+
112+
String DetailRenderTemplate() const final {
113+
std::ostringstream os;
114+
os << "ScheduleError: in new order " << new_order_
115+
<< " iterator " << iter_ << " was placed before its dependent iterator.";
116+
return os.str();
117+
}
118+
119+
IRModule mod() const final { return mod_; }
120+
Array<ObjectRef> LocationsOfInterest() const final { return {}; }
121+
122+
IRModule mod_;
123+
SpIterVar iter_;
124+
Array<SpIterVar> new_order_;
125+
};
126+
127+
std::set<Axis> axes_set;
128+
for (const SpIterVar& sp_iter : new_order) {
129+
Axis axis = sp_iter->axis;
130+
auto try_parent = axis->GetParentAxis();
131+
if (try_parent.defined()) {
132+
Axis parent = try_parent.value();
133+
if (axes_set.find(parent) == axes_set.end()) {
134+
throw DependencyError(self->mod, sp_iter, new_order);
135+
}
136+
}
137+
axes_set.insert(axis);
138+
}
139+
}
140+
141+
97142
SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block,
98143
const Array<SpIterVar>& new_order) {
99144
// Step 1. Check whether the iterators in `new_order` are the same as `block`'s iterators.
100145
CheckValidInputIterators(self, new_order, block->sp_iter_vars);
101146

102147
// Step 2. Check whether the new order does not break the iterator dependency.
103-
// TODO(zihao): rewrite this part.
104-
// CheckDependency(self, block, new_order);
148+
CheckDependency(self, new_order);
105149

106150
// Step 3. Create the new SparseBlock.
107151
ObjectPtr<SparseBlockNode> p_new_block = make_object<SparseBlockNode>(*block.get());

tests/python/sparsetir/test_tir_rgcn.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,16 @@ def rgcn(
8282
F_in = T.dense_fixed(feat_size)
8383
F_out = T.dense_fixed(feat_size)
8484
E = T.match_sparse_buffer(etype, (I, J), "int32")
85-
W = T.match_sparse_buffer(w, (R, F_in, F_out), "float32")
85+
W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
8686
X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32")
8787
Y = T.match_sparse_buffer(y, (I, F_out), "float32")
88+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
8889
with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [
8990
vi, vout, vj, vin,
9091
]:
9192
with T.init():
9293
Y[vi, vout] = 0.
93-
Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vin, vout] * X[vj, vin]
94+
Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin]
9495

9596

9697
@T.prim_func
@@ -179,15 +180,19 @@ def msg_func(edges):
179180
print("dgl high-mem:\t\t", accum / (total - cold_start))
180181

181182
# tir
182-
N, R, FEAT_SIZE, NNZ = lowered_rgcn.params[-4:]
183+
mod = tvm.IRModule.from_expr(rgcn)
184+
mod = tvm.tir.transform.LowerSparseTIR()(mod)
185+
tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn, True)
186+
187+
N, R, FEAT_SIZE, NNZ = mod["main"].params[-4:]
183188
sch = tir.Schedule(
184-
lowered_rgcn.specialize(
189+
mod["main"].specialize(
185190
{N: g.number_of_nodes(), R: g.num_rels, FEAT_SIZE: feat_size, NNZ: g.number_of_edges()}
186191
)
187192
)
188193

189-
outer = sch.get_block("rgcn-forward_0")
190-
inner = sch.get_block("rgcn-forward_1")
194+
outer = sch.get_block("rgcn-forward0")
195+
inner = sch.get_block("rgcn-forward1")
191196
i, f_out = sch.get_loops(outer)
192197
j, f_in = sch.get_loops(inner)
193198
sch.bind(i, "blockIdx.x")

0 commit comments

Comments
 (0)