Skip to content

Commit b7188d2

Browse files
authored
[mlir][sparse] replace specialized buffer setup with util code (#68461)
This completely centralizes all set up related to dim2lvl and lvl2dim for the runtime library (and even parts of direct IR codegen) into one place! And all compatible with the MapRef data structure that should be used in all remaining clients of dim2lvl and lvl2dim. NOTE: the convert_x2y.mlir tests were becoming too overloaded, so I decided to bring them back to the basics; if e.g. more coverage of the foreach is required, it should go into isolated smaller tests.
1 parent ea86fb8 commit b7188d2

File tree

8 files changed

+820
-799
lines changed

8 files changed

+820
-799
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
5757
StridedMemRefType<index_type, 1> *dimSizesRef,
5858
StridedMemRefType<index_type, 1> *lvlSizesRef,
5959
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
60-
StridedMemRefType<index_type, 1> *lvl2dimRef,
61-
StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
60+
StridedMemRefType<index_type, 1> *dim2lvlRef,
61+
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
6262
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr);
6363

6464
/// Tensor-storage method to obtain direct access to the values array.
@@ -85,6 +85,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
8585
#undef DECL_SPARSECOORDINATES
8686

8787
/// Coordinate-scheme method for adding a new element.
88+
/// TODO: remove dim2lvl
8889
#define DECL_ADDELT(VNAME, V) \
8990
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME( \
9091
void *lvlCOO, StridedMemRefType<V, 0> *vref, \

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 25 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -187,25 +187,38 @@ static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,
187187

188188
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
189189
/// the "swiss army knife" method of the sparse runtime support library
190-
/// for materializing sparse tensors into the computation. This abstraction
191-
/// reduces the need to make modifications to client code whenever that
192-
/// API changes.
190+
/// for materializing sparse tensors into the computation. This abstraction
191+
/// reduces the need for modifications when the API changes.
193192
class NewCallParams final {
194193
public:
195-
/// Allocates the `ValueRange` for the `func::CallOp` parameters,
196-
/// but does not initialize them.
194+
/// Allocates the `ValueRange` for the `func::CallOp` parameters.
197195
NewCallParams(OpBuilder &builder, Location loc)
198196
: builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
199197

200198
/// Initializes all static parameters (i.e., those which indicate
201199
/// type-level information such as the encoding and sizes), generating
202200
/// MLIR buffers as needed, and returning `this` for method chaining.
203-
/// This method does not set the action and pointer arguments, since
204-
/// those are handled by `genNewCall` instead.
205-
NewCallParams &genBuffers(SparseTensorType stt, ValueRange dimSizes);
201+
NewCallParams &genBuffers(SparseTensorType stt,
202+
ArrayRef<Value> dimSizesValues) {
203+
const Dimension dimRank = stt.getDimRank();
204+
assert(dimSizesValues.size() == static_cast<size_t>(dimRank));
205+
// Sparsity annotations.
206+
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
207+
// Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
208+
params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
209+
params[kParamLvlSizes] = genReaderBuffers(
210+
builder, loc, stt, dimSizesValues, params[kParamDimSizes],
211+
params[kParamDim2Lvl], params[kParamLvl2Dim]);
212+
// Secondary and primary types encoding.
213+
setTemplateTypes(stt);
214+
// Finally, make note that initialization is complete.
215+
assert(isInitialized() && "Initialization failed");
216+
// And return `this` for method chaining.
217+
return *this;
218+
}
206219

207220
/// (Re)sets the C++ template type parameters, and returns `this`
208-
/// for method chaining. This is already done as part of `genBuffers`,
221+
/// for method chaining. This is already done as part of `genBuffers`,
209222
/// but is factored out so that it can also be called independently
210223
/// whenever subsequent `genNewCall` calls want to reuse the same
211224
/// buffers but different type parameters.
@@ -236,7 +249,7 @@ class NewCallParams final {
236249
// this one-off getter, and to avoid potential mixups)?
237250
Value getDimToLvl() const {
238251
assert(isInitialized() && "Must initialize before getDimToLvl");
239-
return params[kParamDimToLvl];
252+
return params[kParamDim2Lvl];
240253
}
241254

242255
/// Generates a function call, with the current static parameters
@@ -257,8 +270,8 @@ class NewCallParams final {
257270
static constexpr unsigned kParamDimSizes = 0;
258271
static constexpr unsigned kParamLvlSizes = 1;
259272
static constexpr unsigned kParamLvlTypes = 2;
260-
static constexpr unsigned kParamLvlToDim = 3;
261-
static constexpr unsigned kParamDimToLvl = 4;
273+
static constexpr unsigned kParamDim2Lvl = 3;
274+
static constexpr unsigned kParamLvl2Dim = 4;
262275
static constexpr unsigned kParamPosTp = 5;
263276
static constexpr unsigned kParamCrdTp = 6;
264277
static constexpr unsigned kParamValTp = 7;
@@ -271,62 +284,6 @@ class NewCallParams final {
271284
Value params[kNumParams];
272285
};
273286

274-
// TODO: see the note at `_mlir_ciface_newSparseTensor` about how
275-
// the meaning of the various arguments (e.g., "sizes" vs "shapes")
276-
// is inconsistent between the different actions.
277-
NewCallParams &NewCallParams::genBuffers(SparseTensorType stt,
278-
ValueRange dimSizes) {
279-
const Level lvlRank = stt.getLvlRank();
280-
const Dimension dimRank = stt.getDimRank();
281-
// Sparsity annotations.
282-
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
283-
// Dimension-sizes array of the enveloping tensor. Useful for either
284-
// verification of external data, or for construction of internal data.
285-
assert(dimSizes.size() == static_cast<size_t>(dimRank) &&
286-
"Dimension-rank mismatch");
287-
params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
288-
// The level-sizes array must be passed as well, since for arbitrary
289-
// dimToLvl mappings it cannot be trivially reconstructed at runtime.
290-
// For now however, since we're still assuming permutations, we will
291-
// initialize this parameter alongside the `dimToLvl` and `lvlToDim`
292-
// parameters below. We preinitialize `lvlSizes` for code symmetry.
293-
SmallVector<Value> lvlSizes(lvlRank);
294-
// The dimension-to-level mapping and its inverse. We must preinitialize
295-
// `dimToLvl` so that the true branch below can perform random-access
296-
// `operator[]` assignment. We preinitialize `lvlToDim` for code symmetry.
297-
SmallVector<Value> dimToLvl(dimRank);
298-
SmallVector<Value> lvlToDim(lvlRank);
299-
if (!stt.isIdentity()) {
300-
const auto dimToLvlMap = stt.getDimToLvl();
301-
assert(dimToLvlMap.isPermutation());
302-
for (Level l = 0; l < lvlRank; l++) {
303-
// The `d`th source variable occurs in the `l`th result position.
304-
const Dimension d = dimToLvlMap.getDimPosition(l);
305-
dimToLvl[d] = constantIndex(builder, loc, l);
306-
lvlToDim[l] = constantIndex(builder, loc, d);
307-
lvlSizes[l] = dimSizes[d];
308-
}
309-
} else {
310-
// The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
311-
// when `isIdentity`; so no need to re-assert it here.
312-
for (Level l = 0; l < lvlRank; l++) {
313-
dimToLvl[l] = lvlToDim[l] = constantIndex(builder, loc, l);
314-
lvlSizes[l] = dimSizes[l];
315-
}
316-
}
317-
params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
318-
params[kParamLvlToDim] = allocaBuffer(builder, loc, lvlToDim);
319-
params[kParamDimToLvl] = stt.isIdentity()
320-
? params[kParamLvlToDim]
321-
: allocaBuffer(builder, loc, dimToLvl);
322-
// Secondary and primary types encoding.
323-
setTemplateTypes(stt);
324-
// Finally, make note that initialization is complete.
325-
assert(isInitialized() && "Initialization failed");
326-
// And return `this` for method chaining.
327-
return *this;
328-
}
329-
330287
/// Generates a call to obtain the values array.
331288
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
332289
ValueRange ptr) {

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
231231
StridedMemRefType<index_type, 1> *dimSizesRef,
232232
StridedMemRefType<index_type, 1> *lvlSizesRef,
233233
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
234-
StridedMemRefType<index_type, 1> *lvl2dimRef,
235-
StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
234+
StridedMemRefType<index_type, 1> *dim2lvlRef,
235+
StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
236236
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
237237
ASSERT_NO_STRIDE(dimSizesRef);
238238
ASSERT_NO_STRIDE(lvlSizesRef);
@@ -250,6 +250,9 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
250250
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
251251
const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
252252

253+
// Prepare map.
254+
// TODO: start using MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim) below
255+
253256
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
254257
// This is safe because of the static_assert above.
255258
if (posTp == OverheadType::kIndex)
@@ -400,6 +403,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
400403
#undef IMPL_GETOVERHEAD
401404

402405
// TODO: use MapRef here for translation of coordinates
406+
// TOOD: remove dim2lvl
403407
#define IMPL_ADDELT(VNAME, V) \
404408
void *_mlir_ciface_addElt##VNAME( \
405409
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
@@ -540,13 +544,13 @@ void *_mlir_ciface_newSparseTensorFromReader(
540544
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
541545
ASSERT_NO_STRIDE(lvlSizesRef);
542546
ASSERT_NO_STRIDE(lvlTypesRef);
543-
ASSERT_NO_STRIDE(lvl2dimRef);
544547
ASSERT_NO_STRIDE(dim2lvlRef);
548+
ASSERT_NO_STRIDE(lvl2dimRef);
545549
const uint64_t dimRank = reader.getRank();
546550
const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
547551
ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
548-
ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
549552
ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
553+
ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
550554
(void)dimRank;
551555
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
552556
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,16 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
136136
// CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
137137
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
138138
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
139-
// CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<2xindex>
140-
// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<2xindex>
141139
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
140+
// CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
142141
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
143-
// CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<2xindex> to memref<?xindex>
144-
// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<2xindex> to memref<?xindex>
145142
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
143+
// CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
146144
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
147-
// CHECK-DAG: memref.store %[[I]], %[[DimSizes0]][%[[C0]]] : memref<2xindex>
148-
// CHECK-DAG: memref.store %[[J]], %[[DimSizes0]][%[[C1]]] : memref<2xindex>
145+
// CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
146+
// CHECK-DAG: memref.store %[[J]], %[[Sizes0]][%[[C1]]] : memref<2xindex>
149147
// CHECK: %[[NP:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
150-
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
148+
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[Sizes]], %[[Sizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
151149
// CHECK: return %[[T]] : !llvm.ptr<i8>
152150
func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #CSR> {
153151
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSR>

0 commit comments

Comments
 (0)