
Commit 3cf03f1

[mlir][sparse] Adding IsSparseTensorPred and updating ops to use it
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D126994
1 parent: 604016d

File tree

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir

4 files changed: +58 -93 lines

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

+24 -0

@@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   }];
 }
 
+def IsSparseTensorPred
+  : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
+
+// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
+// `RankedTensorOf`, `AnyRankedTensor`.
+
+class SparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, IsSparseTensorPred]>,
+      "sparse tensor",
+      "::mlir::TensorType">;
+
+def AnySparseTensor : SparseTensorOf<[AnyType]>;
+
+class RankedSparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
+      "ranked sparse tensor",
+      "::mlir::TensorType">;
+
+def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+
 #endif // SPARSETENSOR_ATTRDEFS
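For context, the new predicate simply checks that a tensor type carries a #sparse_tensor.encoding attribute. A minimal MLIR sketch of the kind of type it recognizes (the #SparseVector alias and the function below are illustrative, not part of this commit; the encoding fields shown are assumed defaults of the time):

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// Satisfies AnySparseTensor (and AnyRankedSparseTensor): the type carries an encoding.
func.func @annotated(%arg0: tensor<32xf32, #SparseVector>) {
  return
}

// A plain tensor<32xf32> carries no encoding, so IsSparseTensorPred rejects it.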

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+10 -17

@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
 
 def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     Arguments<(ins AnyType:$source)>,
-    Results<(outs TensorOf<[AnyType]>:$result)> {
+    Results<(outs AnySparseTensor:$result)> {
   string summary = "Materializes a new sparse tensor from given source";
   string description = [{
     Materializes a sparse tensor with contents taken from an opaque pointer
@@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 }
 
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts pointers array at given dimension from a tensor";
   let description = [{
@@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
 }
 
 def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts indices array at given dimension from a tensor";
   let description = [{
@@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
 }
 
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
   let summary = "Extracts numerical values array from a tensor";
   let description = [{
@@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
 //===----------------------------------------------------------------------===//
 
 def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
                StridedMemRefRankOf<[Index], [1]>:$indices,
                AnyType:$value)> {
   string summary = "Inserts a value into given sparse tensor in lexicographical index order";
@@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
   }];
   let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
                        " type($tensor) `,` type($indices) `,` type($value)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$values,
                   StridedMemRefRankOf<[I1],[1]>:$filled,
                   StridedMemRefRankOf<[Index],[1]>:$added,
@@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
                        " `,` type($filled) `,` type($added) `,` type($count)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
                StridedMemRefRankOf<[Index],[1]>:$indices,
                AnyStridedMemRefOfRank<1>:$values,
                StridedMemRefRankOf<[I1],[1]>:$filled,
@@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
                        " $added `,` $count attr-dict `:` type($tensor) `,`"
                        " type($indices) `,` type($values) `,` type($filled) `,`"
                        " type($added) `,` type($count)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
-    Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
+    Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>,
     Results<(outs AnyTensor:$result)> {
   let summary =
       "Rematerializes tensor from underlying sparse storage format";
@@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     ```
   }];
   let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
-    Arguments<(ins AnyTensor:$tensor)> {
+    Arguments<(ins AnySparseTensor:$tensor)> {
   string summary = "Releases underlying sparse storage format of given tensor";
   string description = [{
     Releases the underlying sparse storage format for a tensor that
@@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
 }
 
 def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
-    Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
+    Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> {
   string summary = "Outputs a sparse tensor to the given destination";
   string description = [{
     Outputs the contents of a sparse tensor to the destination defined by an
@@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
     ```
   }];
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
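As a usage sketch of the tightened operand constraints (illustrative only; the #SparseVector alias is an assumption, not taken from this commit), ops such as sparse_tensor.pointers and sparse_tensor.values now only accept operands whose tensor type carries a sparse encoding:

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @query(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xf64> {
  %c0 = arith.constant 0 : index
  // Well-typed: the operand is a sparse tensor, so the generated ODS check passes.
  %p = sparse_tensor.pointers %arg0, %c0 : tensor<128xf64, #SparseVector> to memref<?xindex>
  %v = sparse_tensor.values %arg0 : tensor<128xf64, #SparseVector> to memref<?xf64>
  return %v : memref<?xf64>
}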

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

+12 -64

@@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
   return failure();
 }
 
-LogicalResult NewOp::verify() {
-  if (!getSparseTensorEncoding(result().getType()))
-    return emitError("expected a sparse tensor result");
-  return success();
-}
-
 LogicalResult ConvertOp::verify() {
   if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
     if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
@@ -240,77 +234,31 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
 }
 
 LogicalResult ToPointersOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested pointers dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
-      return emitError("unexpected type for pointers");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get pointers");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested pointers dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+    return emitError("unexpected type for pointers");
+  return success();
 }
 
 LogicalResult ToIndicesOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested indices dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
-      return emitError("unexpected type for indices");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get indices");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested indices dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+    return emitError("unexpected type for indices");
+  return success();
 }
 
 LogicalResult ToValuesOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to get values");
   RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
   MemRefType mtp = result().getType().cast<MemRefType>();
   if (ttp.getElementType() != mtp.getElementType())
     return emitError("unexpected mismatch in element types");
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TensorDialect Management Operations.
-//===----------------------------------------------------------------------===//
-
-LogicalResult LexInsertOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for insertion");
-  return success();
-}
-
-LogicalResult ExpandOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for expansion");
-  return success();
-}
-
-LogicalResult CompressOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for compression");
-  return success();
-}
-
-LogicalResult LoadOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to materialize");
-  return success();
-}
-
-LogicalResult ReleaseOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to release");
-  return success();
-}
-
-LogicalResult OutOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for output");
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // TensorDialect Linalg.Generic Operations.
 //===----------------------------------------------------------------------===//
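Note that only the "is this a sparse tensor" checks move into the generated ODS verifier; the remaining C++ verifiers still catch dimension and bit-width mismatches. A hypothetical MLIR case, modeled on the existing pointers_oob test rather than added by this commit (the #SparseVector alias is assumed):

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
  %c1 = arith.constant 1 : index
  // Still rejected by ToPointersOp::verify(): "requested pointers dimension out of bounds".
  %0 = sparse_tensor.pointers %arg0, %c1 : tensor<128xf64, #SparseVector> to memref<?xindex>
  return %0 : memref<?xindex>
}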

mlir/test/Dialect/SparseTensor/invalid.mlir

+12 -12

@@ -1,15 +1,15 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
-  // expected-error@+1 {{expected a sparse tensor result}}
+  // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
   return %0 : tensor<32xf32>
 }
 
 // -----
 
 func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
-  // expected-error@+1 {{expected a sparse tensor to release}}
+  // expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}}
   sparse_tensor.release %arg0 : tensor<4xi32>
   return
 }
@@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
 
 func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 
 func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
   %0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex
 
 func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
   %c = arith.constant 1 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
   %0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   %c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
   %0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
   return %0 : memref<?xindex>
 }
@@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
 // -----
 
 func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
-  // expected-error@+1 {{expected a sparse tensor to get values}}
+  // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
   %0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
   return %0 : memref<?xf32>
 }
@@ -115,23 +115,23 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<
 // -----
 
 func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
-  // expected-error@+1 {{expected a sparse tensor to materialize}}
+  // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
   %0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
 
 // -----
 
 func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
-  // expected-error@+1 {{expected a sparse tensor for insertion}}
+  // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
   return
 }
 
 // -----
 
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
-  // expected-error@+1 {{expected a sparse tensor for expansion}}
+  // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
     : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
@@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
 func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
                                           %arg2: memref<?xf64>, %arg3: memref<?xi1>,
                                           %arg4: memref<?xindex>, %arg5: index) {
-  // expected-error@+1 {{expected a sparse tensor for compression}}
+  // expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
     : tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
 }
@@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
 // -----
 
 func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
-  // expected-error@+1 {{expected a sparse tensor for output}}
+  // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
   return
 }
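For contrast with the invalid cases above, a valid counterpart that satisfies the new result constraint on sparse_tensor.new (illustrative only, not part of the updated test file; the #SparseVector alias is assumed):

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @valid_new(%arg0: !llvm.ptr<i8>) -> tensor<32xf32, #SparseVector> {
  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32, #SparseVector>
  return %0 : tensor<32xf32, #SparseVector>
}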
