Skip to content

Commit e4aa7e3

Browse files
committed
Finish index lowering (maybe)
1 parent 43db26b commit e4aa7e3

File tree

2 files changed

+219
-82
lines changed

2 files changed

+219
-82
lines changed

src/tir/ir/sparse.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* \file sparse.cc
2222
* \brief buffers and formats in sparse tir.
2323
*/
24+
#include <tvm/arith/analyzer.h>
2425
#include <tvm/runtime/registry.h>
2526
#include <tvm/tir/buffer.h>
2627
#include <tvm/tir/sparse.h>
@@ -158,15 +159,27 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
158159
Optional<Axis> axis) {
159160
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
160161

162+
arith::Analyzer ana;
163+
if (axis.defined()) {
164+
CHECK(ana.CanProveEqual(axis.value()->length, max_extent));
165+
}
161166
if (kind != SpIterKind::kDenseFixed) {
162167
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
163168
"specify the axis over which the SpIterVar iterates";
169+
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
170+
if (kind == SpIterKind::kDenseVariable) {
171+
CHECK(axis.value()->IsInstance<DenseFixedAxisNode>()) << err_str;
172+
} else if (kind == SpIterKind::kSparseFixed) {
173+
CHECK(axis.value()->IsInstance<SparseFixedAxisNode>()) << err_str;
174+
} else if (kind == SpIterKind::kSparseVariable) {
175+
CHECK(axis.value()->IsInstance<SparseVariableAxisNode>()) << err_str;
176+
}
164177
}
165178

166179
node->var = Var(std::move(name));
167180
node->max_extent = std::move(max_extent);
168181
node->kind = kind;
169-
node->is_reduction = is_reduction;
182+
node->is_reduction = is_reduction;
170183
node->axis = std::move(axis);
171184
data_ = std::move(node);
172185
}

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 205 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -33,132 +33,256 @@
3333
namespace tvm {
3434
namespace tir {
3535

36-
class SparseTIRLowerer : public StmtExprMutator {
36+
/*!
37+
* \brief Check whether a given SparseBuffer contains the given axis.
38+
* \brief buffer The SparseBuffer to be checked
39+
* \brief axis The axis to be checked
40+
* \return A boolean indicating whether the given SparseBuffer contains the given axis
41+
*/
42+
bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) {
43+
for (int i = 0; i < static_cast<int>(buffer->axes.size()); ++i) {
44+
if (buffer->axes[i].same_as(axis)) {
45+
return true;
46+
}
47+
}
48+
return false;
49+
}
50+
51+
using BufferAccessMap = Map<SparseBuffer, Array<SpIterVar>>;
52+
using DependencyMap =
53+
std::unordered_map<SpIterVar, std::pair<SparseBuffer, int>, ObjectPtrHash, ObjectPtrEqual>;
54+
55+
/*
56+
* \brief For each sparse-fixed or sparse-variable iterator, collect the iterators that it depends
57+
* on.
58+
*/
59+
class AccessAndDependencyCollector : public StmtExprVisitor {
60+
public:
61+
void Collect(Stmt stmt) {
62+
VisitStmt(std::move(stmt));
63+
64+
for (const std::pair<SparseBuffer, Array<SpIterVar>>& kv_pair : buffer_access_map_) {
65+
const SparseBuffer& buffer = kv_pair.first;
66+
int ndim = static_cast<int>(kv_pair.second.size());
67+
for (int k = 0; k < ndim; ++k) {
68+
const SpIterVar& sp_iter = kv_pair.second[k];
69+
if (sp_iter->kind == SpIterKind::kDenseFixed ||
70+
sp_iter->kind == SpIterKind::kDenseVariable ||
71+
!BufferContainsAxis(buffer, sp_iter->axis.value())) {
72+
continue;
73+
}
74+
75+
ICHECK(dependency_map_.count(sp_iter) == 0);
76+
dependency_map_[sp_iter] = std::make_pair(buffer, k);
77+
}
78+
}
79+
}
80+
81+
BufferAccessMap buffer_access_map_;
82+
DependencyMap dependency_map_;
83+
84+
private:
85+
void AddAccessPattern(const SparseBuffer& buffer, const Array<PrimExpr>& indices) {
86+
int ndim = buffer->ndim();
87+
CHECK_EQ(static_cast<int>(indices.size()), ndim);
88+
89+
Array<SpIterVar> iters;
90+
iters.reserve(ndim);
91+
for (int i = 0; i < ndim; ++i) {
92+
const SpIterVarNode* sp_iter = indices[i].as<SpIterVarNode>();
93+
CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";
94+
iters.push_back(GetRef<SpIterVar>(sp_iter));
95+
}
96+
97+
BufferAccessMap::iterator it = buffer_access_map_.find(buffer);
98+
if (it == buffer_access_map_.end()) {
99+
buffer_access_map_.Set(buffer, iters);
100+
} else {
101+
ICHECK_EQ(static_cast<int>((*it).second.size()), ndim);
102+
for (int i = 0; i < ndim; ++i) {
103+
CHECK((*it).second[i].same_as(iters[i]))
104+
<< "ValueError: Currently all accesses to a same buffer are required to be the same";
105+
}
106+
}
107+
}
108+
109+
void VisitStmt_(const SparseBufferStoreNode* store) final {
110+
ExprVisitor::VisitExpr(store->value);
111+
AddAccessPattern(store->buffer, store->indices);
112+
}
113+
114+
void VisitExpr_(const SparseBufferLoadNode* load) final {
115+
AddAccessPattern(load->buffer, load->indices);
116+
}
117+
};
118+
119+
class IndexTransformer : public StmtExprMutator {
120+
public:
121+
explicit IndexTransformer(BufferAccessMap buffer_access_map, DependencyMap dependency_map)
122+
: buffer_access_map_(std::move(buffer_access_map)),
123+
dependency_map_(std::move(dependency_map)) {}
124+
37125
private:
38-
std::pair<Buffer, PrimExpr> LowerIndices(SparseBuffer sp_buffer, Array<PrimExpr> indices) {
126+
PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) {
39127
int ndim = sp_buffer->ndim();
40-
ICHECK_EQ(static_cast<int>(indices.size()), ndim);
128+
int n_lower = static_cast<int>(indices.size());
129+
ICHECK_LE(n_lower, ndim);
130+
41131
PrimExpr lowered_index = Integer(0);
42132

43-
for (int i = 0; i < ndim; ++i) {
133+
for (int i = 0; i < n_lower; ++i) {
44134
const Axis& axis = sp_buffer->axes[i];
45135
const PrimExpr& index = indices[i];
46136

47-
// Stage 1.
137+
// Stage 1. Get the sparse index.
138+
const auto* sp_iter = index.as<SpIterVarNode>();
48139
PrimExpr sp_index{nullptr};
49-
if (const auto* sp_iter = index.as<SpIterVarNode>()) {
50-
SpIterKind kind = sp_iter->kind;
51-
if (kind == SpIterKind::kDenseFixed) {
52-
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
53-
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
54-
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
55-
sp_index = GetRef<SpIterVar>(sp_iter);
56-
} else {
57-
PrimExpr l = LowerIndex(lowered_index, sp_buffer, i, 0);
58-
PrimExpr r = LowerIndex(Add(lowered_index, 1), sp_buffer, i, 0);
59-
Var buffer_var;
60-
if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
61-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
62-
buffer_var = sf_axis->indices->data;
63-
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
64-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
65-
buffer_var = sv_axis->indices->data;
66-
} else {
67-
LOG(FATAL) << "Cannot reach here";
68-
}
69-
sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r));
70-
}
71-
} else if (kind == SpIterKind::kDenseVariable) {
72-
const auto* dv_axis = axis.as<DenseVariableAxisNode>();
73-
CHECK(dv_axis != nullptr);
74-
CHECK(sp_iter->axis.defined());
140+
CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";
141+
142+
PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0);
143+
PrimExpr r = AccumulateLowerIndex(Add(lowered_index, 1), sp_buffer, i, 0);
144+
145+
SpIterKind kind = sp_iter->kind;
146+
if (kind == SpIterKind::kDenseFixed) {
147+
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
148+
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
149+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
75150
sp_index = GetRef<SpIterVar>(sp_iter);
76-
} else if (kind == SpIterKind::kSparseFixed) {
77-
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
78-
CHECK(sp_iter->axis.defined());
79-
const Axis& iterated_axis = sp_iter->axis.value();
80-
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
81-
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
82-
// Todo: convert to dense
83-
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
84-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
85-
if (iterated_axis.get() == sf_axis) {
86-
sp_index = GetRef<SpIterVar>(sp_iter);
87-
} else {
88-
// Todo: convert to dense and do binary search
89-
}
151+
} else {
152+
Var buffer_var;
153+
if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
154+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
155+
buffer_var = sf_axis->indices->data;
90156
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
91-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
92-
// Todo: convert to dense and do binary search
157+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
158+
buffer_var = sv_axis->indices->data;
93159
} else {
94160
LOG(FATAL) << "Cannot reach here";
95161
}
96-
} else {
97-
CHECK(kind == SpIterKind::kSparseVariable);
98-
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
99-
CHECK(sp_iter->axis.defined());
100-
const Axis& iterated_axis = sp_iter->axis.value();
101-
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
102-
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
103-
// Todo: convert to dense
104-
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
105-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
106-
// Todo: convert to dense and do binary search
107-
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
108-
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
109-
if (iterated_axis.get() == sv_axis) {
110-
sp_index = GetRef<SpIterVar>(sp_iter);
111-
} else {
112-
// Todo: convert to dense and do binary search
113-
}
162+
sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r));
163+
}
164+
} else if (kind == SpIterKind::kDenseVariable) {
165+
const auto* dv_axis = axis.as<DenseVariableAxisNode>();
166+
CHECK(dv_axis != nullptr);
167+
CHECK(sp_iter->axis.defined());
168+
sp_index = GetRef<SpIterVar>(sp_iter);
169+
} else if (kind == SpIterKind::kSparseFixed) {
170+
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
171+
CHECK(sp_iter->axis.defined());
172+
const Axis& iterated_axis = sp_iter->axis.value();
173+
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
174+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
175+
sp_index = GetDenseValue(sp_iter);
176+
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
177+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
178+
if (iterated_axis.get() == sf_axis) {
179+
sp_index = GetRef<SpIterVar>(sp_iter);
114180
} else {
115-
LOG(FATAL) << "Cannot reach here";
181+
sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
182+
std::move(r));
116183
}
184+
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
185+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
186+
sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
187+
std::move(r));
188+
} else {
189+
LOG(FATAL) << "Cannot reach here";
117190
}
118191
} else {
119-
// Todo
192+
CHECK(kind == SpIterKind::kSparseVariable);
193+
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
194+
CHECK(sp_iter->axis.defined());
195+
const Axis& iterated_axis = sp_iter->axis.value();
196+
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
197+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
198+
sp_index = GetDenseValue(sp_iter);
199+
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
200+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
201+
sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
202+
std::move(r));
203+
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
204+
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
205+
if (iterated_axis.get() == sv_axis) {
206+
sp_index = GetRef<SpIterVar>(sp_iter);
207+
} else {
208+
sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
209+
std::move(r));
210+
}
211+
} else {
212+
LOG(FATAL) << "Cannot reach here";
213+
}
120214
}
121215

122-
// Stage 2.
123-
lowered_index = LowerIndex(std::move(lowered_index), sp_buffer, i, sp_index);
216+
// Stage 2. Accumulate the lowered index.
217+
lowered_index =
218+
AccumulateLowerIndex(std::move(lowered_index), sp_buffer, i, std::move(sp_index));
124219
}
125220

126-
return std::make_pair(sp_buffer->data, lowered_index);
221+
return lowered_index;
127222
}
128223

129-
PrimExpr LowerIndex(PrimExpr prev_lowered_index, SparseBuffer sp_buffer, int dim,
130-
PrimExpr index) {
224+
PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim,
225+
PrimExpr index) {
131226
const Axis& axis = sp_buffer->axes[dim];
132227
if (axis->IsInstance<DenseFixedAxisNode>() || axis->IsInstance<SparseFixedAxisNode>()) {
133-
return ana.Simplify(prev_lowered_index * axis->length + index);
228+
return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index));
134229
} else if (const auto* dv_axis = axis.as<DenseVariableAxisNode>()) {
135-
return ana.Simplify(Add(BufferLoad(dv_axis->indptr, {prev_lowered_index}), index));
230+
return ana_.Simplify(
231+
Add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
136232
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
137-
return ana.Simplify(Add(BufferLoad(sv_axis->indptr, {prev_lowered_index}), index));
233+
return ana_.Simplify(
234+
Add(BufferLoad(sv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
138235
}
139236
LOG(FATAL) << "Cannot reach here";
140237
throw;
141238
}
142239

240+
PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) {
241+
SpIterKind kind = sp_iter->kind;
242+
CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable);
243+
Axis iterated_axis = sp_iter->axis.value();
244+
245+
std::pair<SparseBuffer, int> depended_pair = dependency_map_[GetRef<SpIterVar>(sp_iter)];
246+
Array<SpIterVar> buffer_access_iters = buffer_access_map_[depended_pair.first];
247+
int n_depended = depended_pair.second;
248+
249+
Array<PrimExpr> depended_iters{buffer_access_iters.begin(),
250+
buffer_access_iters.begin() + n_depended};
251+
PrimExpr lowered_indices = LowerIndices(depended_pair.first, depended_iters);
252+
253+
if (kind == SpIterKind::kSparseFixed) {
254+
return BufferLoad(Downcast<SparseFixedAxis>(iterated_axis)->indices,
255+
{std::move(lowered_indices)});
256+
} else {
257+
return BufferLoad(Downcast<SparseVariableAxis>(iterated_axis)->indices,
258+
{std::move(lowered_indices)});
259+
}
260+
}
261+
143262
PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final {
144-
std::pair<Buffer, PrimExpr> res = LowerIndices(load->buffer, load->indices);
145-
return BufferLoad(std::move(res.first), {std::move(res.second)});
263+
PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices);
264+
return BufferLoad(load->buffer->data, {std::move(lowered_indices)});
146265
}
147266

148267
Stmt VisitStmt_(const SparseBufferStoreNode* store) final {
149268
PrimExpr value = ExprMutator::VisitExpr(store->value);
150-
std::pair<Buffer, PrimExpr> res = LowerIndices(store->buffer, store->indices);
151-
return BufferStore(std::move(res.first), std::move(value), {std::move(res.second)});
269+
PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices);
270+
return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)});
152271
}
153272

154-
arith::Analyzer ana;
273+
BufferAccessMap buffer_access_map_;
274+
DependencyMap dependency_map_;
275+
arith::Analyzer ana_;
155276
};
156277

157278
PrimFunc LowerSparseTIR(PrimFunc f) {
158279
// Only apply this pass to TIR that is not from TE schedules
159280
if (!IsFromLegacyTESchedule(f)) {
160281
PrimFuncNode* fptr = f.CopyOnWrite();
161-
fptr->body = SparseTIRLowerer()(std::move(f->body));
282+
AccessAndDependencyCollector collector;
283+
collector.Collect(f->body);
284+
fptr->body = IndexTransformer(collector.buffer_access_map_,
285+
collector.dependency_map_)(std::move(f->body));
162286
return f;
163287
} else {
164288
return f;

0 commit comments

Comments
 (0)