@@ -187,25 +187,38 @@ static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
187
187
188
188
// / This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
189
189
// / the "swiss army knife" method of the sparse runtime support library
190
- // / for materializing sparse tensors into the computation. This abstraction
191
- // / reduces the need to make modifications to client code whenever that
192
- // / API changes.
190
+ // / for materializing sparse tensors into the computation. This abstraction
191
+ // / reduces the need for modifications when the API changes.
193
192
class NewCallParams final {
194
193
public:
195
- // / Allocates the `ValueRange` for the `func::CallOp` parameters,
196
- // / but does not initialize them.
194
+ // / Allocates the `ValueRange` for the `func::CallOp` parameters.
197
195
NewCallParams (OpBuilder &builder, Location loc)
198
196
: builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
199
197
200
198
// / Initializes all static parameters (i.e., those which indicate
201
199
// / type-level information such as the encoding and sizes), generating
202
200
// / MLIR buffers as needed, and returning `this` for method chaining.
203
- // / This method does not set the action and pointer arguments, since
204
- // / those are handled by `genNewCall` instead.
205
- NewCallParams &genBuffers (SparseTensorType stt, ValueRange dimSizes);
201
+ NewCallParams &genBuffers (SparseTensorType stt,
202
+ ArrayRef<Value> dimSizesValues) {
203
+ const Dimension dimRank = stt.getDimRank ();
204
+ assert (dimSizesValues.size () == static_cast <size_t >(dimRank));
205
+ // Sparsity annotations.
206
+ params[kParamLvlTypes ] = genLvlTypesBuffer (builder, loc, stt);
207
+ // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
208
+ params[kParamDimSizes ] = allocaBuffer (builder, loc, dimSizesValues);
209
+ params[kParamLvlSizes ] = genReaderBuffers (
210
+ builder, loc, stt, dimSizesValues, params[kParamDimSizes ],
211
+ params[kParamDim2Lvl ], params[kParamLvl2Dim ]);
212
+ // Secondary and primary types encoding.
213
+ setTemplateTypes (stt);
214
+ // Finally, make note that initialization is complete.
215
+ assert (isInitialized () && " Initialization failed" );
216
+ // And return `this` for method chaining.
217
+ return *this ;
218
+ }
206
219
207
220
// / (Re)sets the C++ template type parameters, and returns `this`
208
- // / for method chaining. This is already done as part of `genBuffers`,
221
+ // / for method chaining. This is already done as part of `genBuffers`,
209
222
// / but is factored out so that it can also be called independently
210
223
// / whenever subsequent `genNewCall` calls want to reuse the same
211
224
// / buffers but different type parameters.
@@ -236,7 +249,7 @@ class NewCallParams final {
236
249
// this one-off getter, and to avoid potential mixups)?
237
250
Value getDimToLvl () const {
238
251
assert (isInitialized () && " Must initialize before getDimToLvl" );
239
- return params[kParamDimToLvl ];
252
+ return params[kParamDim2Lvl ];
240
253
}
241
254
242
255
// / Generates a function call, with the current static parameters
@@ -257,8 +270,8 @@ class NewCallParams final {
257
270
static constexpr unsigned kParamDimSizes = 0 ;
258
271
static constexpr unsigned kParamLvlSizes = 1 ;
259
272
static constexpr unsigned kParamLvlTypes = 2 ;
260
- static constexpr unsigned kParamLvlToDim = 3 ;
261
- static constexpr unsigned kParamDimToLvl = 4 ;
273
+ static constexpr unsigned kParamDim2Lvl = 3 ;
274
+ static constexpr unsigned kParamLvl2Dim = 4 ;
262
275
static constexpr unsigned kParamPosTp = 5 ;
263
276
static constexpr unsigned kParamCrdTp = 6 ;
264
277
static constexpr unsigned kParamValTp = 7 ;
@@ -271,62 +284,6 @@ class NewCallParams final {
271
284
Value params[kNumParams ];
272
285
};
273
286
274
- // TODO: see the note at `_mlir_ciface_newSparseTensor` about how
275
- // the meaning of the various arguments (e.g., "sizes" vs "shapes")
276
- // is inconsistent between the different actions.
277
- NewCallParams &NewCallParams::genBuffers (SparseTensorType stt,
278
- ValueRange dimSizes) {
279
- const Level lvlRank = stt.getLvlRank ();
280
- const Dimension dimRank = stt.getDimRank ();
281
- // Sparsity annotations.
282
- params[kParamLvlTypes ] = genLvlTypesBuffer (builder, loc, stt);
283
- // Dimension-sizes array of the enveloping tensor. Useful for either
284
- // verification of external data, or for construction of internal data.
285
- assert (dimSizes.size () == static_cast <size_t >(dimRank) &&
286
- " Dimension-rank mismatch" );
287
- params[kParamDimSizes ] = allocaBuffer (builder, loc, dimSizes);
288
- // The level-sizes array must be passed as well, since for arbitrary
289
- // dimToLvl mappings it cannot be trivially reconstructed at runtime.
290
- // For now however, since we're still assuming permutations, we will
291
- // initialize this parameter alongside the `dimToLvl` and `lvlToDim`
292
- // parameters below. We preinitialize `lvlSizes` for code symmetry.
293
- SmallVector<Value> lvlSizes (lvlRank);
294
- // The dimension-to-level mapping and its inverse. We must preinitialize
295
- // `dimToLvl` so that the true branch below can perform random-access
296
- // `operator[]` assignment. We preinitialize `lvlToDim` for code symmetry.
297
- SmallVector<Value> dimToLvl (dimRank);
298
- SmallVector<Value> lvlToDim (lvlRank);
299
- if (!stt.isIdentity ()) {
300
- const auto dimToLvlMap = stt.getDimToLvl ();
301
- assert (dimToLvlMap.isPermutation ());
302
- for (Level l = 0 ; l < lvlRank; l++) {
303
- // The `d`th source variable occurs in the `l`th result position.
304
- const Dimension d = dimToLvlMap.getDimPosition (l);
305
- dimToLvl[d] = constantIndex (builder, loc, l);
306
- lvlToDim[l] = constantIndex (builder, loc, d);
307
- lvlSizes[l] = dimSizes[d];
308
- }
309
- } else {
310
- // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
311
- // when `isIdentity`; so no need to re-assert it here.
312
- for (Level l = 0 ; l < lvlRank; l++) {
313
- dimToLvl[l] = lvlToDim[l] = constantIndex (builder, loc, l);
314
- lvlSizes[l] = dimSizes[l];
315
- }
316
- }
317
- params[kParamLvlSizes ] = allocaBuffer (builder, loc, lvlSizes);
318
- params[kParamLvlToDim ] = allocaBuffer (builder, loc, lvlToDim);
319
- params[kParamDimToLvl ] = stt.isIdentity ()
320
- ? params[kParamLvlToDim ]
321
- : allocaBuffer (builder, loc, dimToLvl);
322
- // Secondary and primary types encoding.
323
- setTemplateTypes (stt);
324
- // Finally, make note that initialization is complete.
325
- assert (isInitialized () && " Initialization failed" );
326
- // And return `this` for method chaining.
327
- return *this ;
328
- }
329
-
330
287
// / Generates a call to obtain the values array.
331
288
static Value genValuesCall (OpBuilder &builder, Location loc, ShapedType tp,
332
289
ValueRange ptr) {
0 commit comments