|
33 | 33 | namespace tvm { |
34 | 34 | namespace tir { |
35 | 35 |
|
36 | | -class SparseTIRLowerer : public StmtExprMutator { |
| 36 | +/*! |
| 37 | + * \brief Check whether a given SparseBuffer contains the given axis. |
| 38 | + * \brief buffer The SparseBuffer to be checked |
| 39 | + * \brief axis The axis to be checked |
| 40 | + * \return A boolean indicating whether the given SparseBuffer contains the given axis |
| 41 | + */ |
| 42 | +bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) { |
| 43 | + for (int i = 0; i < static_cast<int>(buffer->axes.size()); ++i) { |
| 44 | + if (buffer->axes[i].same_as(axis)) { |
| 45 | + return true; |
| 46 | + } |
| 47 | + } |
| 48 | + return false; |
| 49 | +} |
| 50 | + |
| 51 | +using BufferAccessMap = Map<SparseBuffer, Array<SpIterVar>>; |
| 52 | +using DependencyMap = |
| 53 | + std::unordered_map<SpIterVar, std::pair<SparseBuffer, int>, ObjectPtrHash, ObjectPtrEqual>; |
| 54 | + |
| 55 | +/*! |
| 56 | + * \brief For each sparse-fixed or sparse-variable iterator, collect the iterators that it depends |
| 57 | + * on. |
| 58 | + */ |
| 59 | +class AccessAndDependencyCollector : public StmtExprVisitor { |
| 60 | + public: |
| 61 | + void Collect(Stmt stmt) { |
| 62 | + VisitStmt(std::move(stmt)); |
| 63 | + |
| 64 | + for (const std::pair<SparseBuffer, Array<SpIterVar>>& kv_pair : buffer_access_map_) { |
| 65 | + const SparseBuffer& buffer = kv_pair.first; |
| 66 | + int ndim = static_cast<int>(kv_pair.second.size()); |
| 67 | + for (int k = 0; k < ndim; ++k) { |
| 68 | + const SpIterVar& sp_iter = kv_pair.second[k]; |
| 69 | + if (sp_iter->kind == SpIterKind::kDenseFixed || |
| 70 | + sp_iter->kind == SpIterKind::kDenseVariable || |
| 71 | + !BufferContainsAxis(buffer, sp_iter->axis.value())) { |
| 72 | + continue; |
| 73 | + } |
| 74 | + |
| 75 | + ICHECK(dependency_map_.count(sp_iter) == 0); |
| 76 | + dependency_map_[sp_iter] = std::make_pair(buffer, k); |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + BufferAccessMap buffer_access_map_; |
| 82 | + DependencyMap dependency_map_; |
| 83 | + |
| 84 | + private: |
| 85 | + void AddAccessPattern(const SparseBuffer& buffer, const Array<PrimExpr>& indices) { |
| 86 | + int ndim = buffer->ndim(); |
| 87 | + CHECK_EQ(static_cast<int>(indices.size()), ndim); |
| 88 | + |
| 89 | + Array<SpIterVar> iters; |
| 90 | + iters.reserve(ndim); |
| 91 | + for (int i = 0; i < ndim; ++i) { |
| 92 | + const SpIterVarNode* sp_iter = indices[i].as<SpIterVarNode>(); |
| 93 | + CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar"; |
| 94 | + iters.push_back(GetRef<SpIterVar>(sp_iter)); |
| 95 | + } |
| 96 | + |
| 97 | + BufferAccessMap::iterator it = buffer_access_map_.find(buffer); |
| 98 | + if (it == buffer_access_map_.end()) { |
| 99 | + buffer_access_map_.Set(buffer, iters); |
| 100 | + } else { |
| 101 | + ICHECK_EQ(static_cast<int>((*it).second.size()), ndim); |
| 102 | + for (int i = 0; i < ndim; ++i) { |
| 103 | + CHECK((*it).second[i].same_as(iters[i])) |
| 104 | + << "ValueError: Currently all accesses to a same buffer are required to be the same"; |
| 105 | + } |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + void VisitStmt_(const SparseBufferStoreNode* store) final { |
| 110 | + ExprVisitor::VisitExpr(store->value); |
| 111 | + AddAccessPattern(store->buffer, store->indices); |
| 112 | + } |
| 113 | + |
| 114 | + void VisitExpr_(const SparseBufferLoadNode* load) final { |
| 115 | + AddAccessPattern(load->buffer, load->indices); |
| 116 | + } |
| 117 | +}; |
| 118 | + |
| 119 | +class IndexTransformer : public StmtExprMutator { |
| 120 | + public: |
| 121 | + explicit IndexTransformer(BufferAccessMap buffer_access_map, DependencyMap dependency_map) |
| 122 | + : buffer_access_map_(std::move(buffer_access_map)), |
| 123 | + dependency_map_(std::move(dependency_map)) {} |
| 124 | + |
37 | 125 | private: |
38 | | - std::pair<Buffer, PrimExpr> LowerIndices(SparseBuffer sp_buffer, Array<PrimExpr> indices) { |
| 126 | + PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) { |
39 | 127 | int ndim = sp_buffer->ndim(); |
40 | | - ICHECK_EQ(static_cast<int>(indices.size()), ndim); |
| 128 | + int n_lower = static_cast<int>(indices.size()); |
| 129 | + ICHECK_LE(n_lower, ndim); |
| 130 | + |
41 | 131 | PrimExpr lowered_index = Integer(0); |
42 | 132 |
|
43 | | - for (int i = 0; i < ndim; ++i) { |
| 133 | + for (int i = 0; i < n_lower; ++i) { |
44 | 134 | const Axis& axis = sp_buffer->axes[i]; |
45 | 135 | const PrimExpr& index = indices[i]; |
46 | 136 |
|
47 | | - // Stage 1. |
| 137 | + // Stage 1. Get the sparse index. |
| 138 | + const auto* sp_iter = index.as<SpIterVarNode>(); |
48 | 139 | PrimExpr sp_index{nullptr}; |
49 | | - if (const auto* sp_iter = index.as<SpIterVarNode>()) { |
50 | | - SpIterKind kind = sp_iter->kind; |
51 | | - if (kind == SpIterKind::kDenseFixed) { |
52 | | - CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
53 | | - if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
54 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
55 | | - sp_index = GetRef<SpIterVar>(sp_iter); |
56 | | - } else { |
57 | | - PrimExpr l = LowerIndex(lowered_index, sp_buffer, i, 0); |
58 | | - PrimExpr r = LowerIndex(Add(lowered_index, 1), sp_buffer, i, 0); |
59 | | - Var buffer_var; |
60 | | - if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
61 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
62 | | - buffer_var = sf_axis->indices->data; |
63 | | - } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
64 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
65 | | - buffer_var = sv_axis->indices->data; |
66 | | - } else { |
67 | | - LOG(FATAL) << "Cannot reach here"; |
68 | | - } |
69 | | - sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r)); |
70 | | - } |
71 | | - } else if (kind == SpIterKind::kDenseVariable) { |
72 | | - const auto* dv_axis = axis.as<DenseVariableAxisNode>(); |
73 | | - CHECK(dv_axis != nullptr); |
74 | | - CHECK(sp_iter->axis.defined()); |
| 140 | + CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar"; |
| 141 | + |
| 142 | + PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0); |
| 143 | + PrimExpr r = AccumulateLowerIndex(Add(lowered_index, 1), sp_buffer, i, 0); |
| 144 | + |
| 145 | + SpIterKind kind = sp_iter->kind; |
| 146 | + if (kind == SpIterKind::kDenseFixed) { |
| 147 | + CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
| 148 | + if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
| 149 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
75 | 150 | sp_index = GetRef<SpIterVar>(sp_iter); |
76 | | - } else if (kind == SpIterKind::kSparseFixed) { |
77 | | - CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
78 | | - CHECK(sp_iter->axis.defined()); |
79 | | - const Axis& iterated_axis = sp_iter->axis.value(); |
80 | | - if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
81 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
82 | | - // Todo: convert to dense |
83 | | - } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
84 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
85 | | - if (iterated_axis.get() == sf_axis) { |
86 | | - sp_index = GetRef<SpIterVar>(sp_iter); |
87 | | - } else { |
88 | | - // Todo: convert to dense and do binary search |
89 | | - } |
| 151 | + } else { |
| 152 | + Var buffer_var; |
| 153 | + if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
| 154 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
| 155 | + buffer_var = sf_axis->indices->data; |
90 | 156 | } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
91 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
92 | | - // Todo: convert to dense and do binary search |
| 157 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
| 158 | + buffer_var = sv_axis->indices->data; |
93 | 159 | } else { |
94 | 160 | LOG(FATAL) << "Cannot reach here"; |
95 | 161 | } |
96 | | - } else { |
97 | | - CHECK(kind == SpIterKind::kSparseVariable); |
98 | | - CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
99 | | - CHECK(sp_iter->axis.defined()); |
100 | | - const Axis& iterated_axis = sp_iter->axis.value(); |
101 | | - if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
102 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
103 | | - // Todo: convert to dense |
104 | | - } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
105 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
106 | | - // Todo: convert to dense and do binary search |
107 | | - } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
108 | | - CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
109 | | - if (iterated_axis.get() == sv_axis) { |
110 | | - sp_index = GetRef<SpIterVar>(sp_iter); |
111 | | - } else { |
112 | | - // Todo: convert to dense and do binary search |
113 | | - } |
| 162 | + sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r)); |
| 163 | + } |
| 164 | + } else if (kind == SpIterKind::kDenseVariable) { |
| 165 | + const auto* dv_axis = axis.as<DenseVariableAxisNode>(); |
| 166 | + CHECK(dv_axis != nullptr); |
| 167 | + CHECK(sp_iter->axis.defined()); |
| 168 | + sp_index = GetRef<SpIterVar>(sp_iter); |
| 169 | + } else if (kind == SpIterKind::kSparseFixed) { |
| 170 | + CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
| 171 | + CHECK(sp_iter->axis.defined()); |
| 172 | + const Axis& iterated_axis = sp_iter->axis.value(); |
| 173 | + if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
| 174 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
| 175 | + sp_index = GetDenseValue(sp_iter); |
| 176 | + } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
| 177 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
| 178 | + if (iterated_axis.get() == sf_axis) { |
| 179 | + sp_index = GetRef<SpIterVar>(sp_iter); |
114 | 180 | } else { |
115 | | - LOG(FATAL) << "Cannot reach here"; |
| 181 | + sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), |
| 182 | + std::move(r)); |
116 | 183 | } |
| 184 | + } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
| 185 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
| 186 | + sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), |
| 187 | + std::move(r)); |
| 188 | + } else { |
| 189 | + LOG(FATAL) << "Cannot reach here"; |
117 | 190 | } |
118 | 191 | } else { |
119 | | - // Todo |
| 192 | + CHECK(kind == SpIterKind::kSparseVariable); |
| 193 | + CHECK(!axis->IsInstance<DenseVariableAxisNode>()); |
| 194 | + CHECK(sp_iter->axis.defined()); |
| 195 | + const Axis& iterated_axis = sp_iter->axis.value(); |
| 196 | + if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) { |
| 197 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length)); |
| 198 | + sp_index = GetDenseValue(sp_iter); |
| 199 | + } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) { |
| 200 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length)); |
| 201 | + sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l), |
| 202 | + std::move(r)); |
| 203 | + } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
| 204 | + CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length)); |
| 205 | + if (iterated_axis.get() == sv_axis) { |
| 206 | + sp_index = GetRef<SpIterVar>(sp_iter); |
| 207 | + } else { |
| 208 | + sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l), |
| 209 | + std::move(r)); |
| 210 | + } |
| 211 | + } else { |
| 212 | + LOG(FATAL) << "Cannot reach here"; |
| 213 | + } |
120 | 214 | } |
121 | 215 |
|
122 | | - // Stage 2. |
123 | | - lowered_index = LowerIndex(std::move(lowered_index), sp_buffer, i, sp_index); |
| 216 | + // Stage 2. Accumulate the lowered index. |
| 217 | + lowered_index = |
| 218 | + AccumulateLowerIndex(std::move(lowered_index), sp_buffer, i, std::move(sp_index)); |
124 | 219 | } |
125 | 220 |
|
126 | | - return std::make_pair(sp_buffer->data, lowered_index); |
| 221 | + return lowered_index; |
127 | 222 | } |
128 | 223 |
|
129 | | - PrimExpr LowerIndex(PrimExpr prev_lowered_index, SparseBuffer sp_buffer, int dim, |
130 | | - PrimExpr index) { |
| 224 | + PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim, |
| 225 | + PrimExpr index) { |
131 | 226 | const Axis& axis = sp_buffer->axes[dim]; |
132 | 227 | if (axis->IsInstance<DenseFixedAxisNode>() || axis->IsInstance<SparseFixedAxisNode>()) { |
133 | | - return ana.Simplify(prev_lowered_index * axis->length + index); |
| 228 | + return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index)); |
134 | 229 | } else if (const auto* dv_axis = axis.as<DenseVariableAxisNode>()) { |
135 | | - return ana.Simplify(Add(BufferLoad(dv_axis->indptr, {prev_lowered_index}), index)); |
| 230 | + return ana_.Simplify( |
| 231 | + Add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index))); |
136 | 232 | } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) { |
137 | | - return ana.Simplify(Add(BufferLoad(sv_axis->indptr, {prev_lowered_index}), index)); |
| 233 | + return ana_.Simplify( |
| 234 | + Add(BufferLoad(sv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index))); |
138 | 235 | } |
139 | 236 | LOG(FATAL) << "Cannot reach here"; |
140 | 237 | throw; |
141 | 238 | } |
142 | 239 |
|
| 240 | + PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) { |
| 241 | + SpIterKind kind = sp_iter->kind; |
| 242 | + CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable); |
| 243 | + Axis iterated_axis = sp_iter->axis.value(); |
| 244 | + |
| 245 | + std::pair<SparseBuffer, int> depended_pair = dependency_map_[GetRef<SpIterVar>(sp_iter)]; |
| 246 | + Array<SpIterVar> buffer_access_iters = buffer_access_map_[depended_pair.first]; |
| 247 | + int n_depended = depended_pair.second; |
| 248 | + |
| 249 | + Array<PrimExpr> depended_iters{buffer_access_iters.begin(), |
| 250 | + buffer_access_iters.begin() + n_depended}; |
| 251 | + PrimExpr lowered_indices = LowerIndices(depended_pair.first, depended_iters); |
| 252 | + |
| 253 | + if (kind == SpIterKind::kSparseFixed) { |
| 254 | + return BufferLoad(Downcast<SparseFixedAxis>(iterated_axis)->indices, |
| 255 | + {std::move(lowered_indices)}); |
| 256 | + } else { |
| 257 | + return BufferLoad(Downcast<SparseVariableAxis>(iterated_axis)->indices, |
| 258 | + {std::move(lowered_indices)}); |
| 259 | + } |
| 260 | + } |
| 261 | + |
143 | 262 | PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final { |
144 | | - std::pair<Buffer, PrimExpr> res = LowerIndices(load->buffer, load->indices); |
145 | | - return BufferLoad(std::move(res.first), {std::move(res.second)}); |
| 263 | + PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices); |
| 264 | + return BufferLoad(load->buffer->data, {std::move(lowered_indices)}); |
146 | 265 | } |
147 | 266 |
|
148 | 267 | Stmt VisitStmt_(const SparseBufferStoreNode* store) final { |
149 | 268 | PrimExpr value = ExprMutator::VisitExpr(store->value); |
150 | | - std::pair<Buffer, PrimExpr> res = LowerIndices(store->buffer, store->indices); |
151 | | - return BufferStore(std::move(res.first), std::move(value), {std::move(res.second)}); |
| 269 | + PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices); |
| 270 | + return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)}); |
152 | 271 | } |
153 | 272 |
|
154 | | - arith::Analyzer ana; |
| 273 | + BufferAccessMap buffer_access_map_; |
| 274 | + DependencyMap dependency_map_; |
| 275 | + arith::Analyzer ana_; |
155 | 276 | }; |
156 | 277 |
|
157 | 278 | PrimFunc LowerSparseTIR(PrimFunc f) { |
158 | 279 | // Only apply this pass to TIR that is not from TE schedules |
159 | 280 | if (!IsFromLegacyTESchedule(f)) { |
160 | 281 | PrimFuncNode* fptr = f.CopyOnWrite(); |
161 | | - fptr->body = SparseTIRLowerer()(std::move(f->body)); |
| 282 | + AccessAndDependencyCollector collector; |
| 283 | + collector.Collect(f->body); |
| 284 | + fptr->body = IndexTransformer(collector.buffer_access_map_, |
| 285 | + collector.dependency_map_)(std::move(f->body)); |
162 | 286 | return f; |
163 | 287 | } else { |
164 | 288 | return f; |
|
0 commit comments