Skip to content

Commit 619bfe8

Browse files
committed
[mlir][sparse] support new kind of scalar in sparse linalg generic op
We have several ways of introducing a scalar invariant value into linalg generic ops (should we limit this somewhat?). This revision makes sure we handle all of them correctly in the sparse compiler. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D104335
1 parent a993bb0 commit 619bfe8

File tree

2 files changed

+99
-16
lines changed

2 files changed

+99
-16
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,17 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
458458
Value val) {
459459
if (auto arg = val.dyn_cast<BlockArgument>()) {
460460
unsigned argN = arg.getArgNumber();
461-
// Any parameter of the generic op is considered a tensor,
462-
// indexed by the implicit loop bounds.
463-
if (arg.getOwner()->getParentOp() == op)
464-
return merger.addExp(Kind::kTensor, argN);
465-
// Any parameter of a higher op is invariant.
461+
// Any argument of the generic op that is not marked as a scalar
462+
// argument is considered a tensor, indexed by the implicit loop
463+
// bounds. This includes rank-0 tensor arguments.
464+
if (arg.getOwner()->getParentOp() == op) {
465+
OpOperand *t = op.getInputAndOutputOperands()[argN];
466+
if (!op.isScalar(t))
467+
return merger.addExp(Kind::kTensor, argN);
468+
val = t->get(); // get scalar value
469+
}
470+
// Any other argument (marked as scalar argument for the generic op
471+
// or belonging to an enveloping op) is considered invariant.
466472
return merger.addExp(Kind::kInvariant, val);
467473
}
468474
Operation *def = val.getDefiningOp();
@@ -719,9 +725,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
719725
}
720726
// Actual load.
721727
SmallVector<Value, 4> args;
722-
OpOperand *t = merger.exp(exp).e0 < op.getNumInputs()
723-
? op.getInputOperand(merger.exp(exp).e0)
724-
: op.getOutputOperand(0);
728+
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
725729
unsigned tensor = t->getOperandNumber();
726730
auto map = op.getTiedIndexingMap(t);
727731
auto enc = getSparseTensorEncoding(t->get().getType());
@@ -919,11 +923,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
919923
if (merger.exp(exp).kind == Kind::kTensor) {
920924
// Inspect tensor indices.
921925
bool atLevel = ldx == -1u;
922-
OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
923-
? op.getInputOperand(merger.exp(exp).e0)
924-
: op.getOutputOperand(0);
925-
auto map = op.getTiedIndexingMap(tensor);
926-
auto enc = getSparseTensorEncoding(tensor->get().getType());
926+
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
927+
auto map = op.getTiedIndexingMap(t);
928+
auto enc = getSparseTensorEncoding(t->get().getType());
927929
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
928930
unsigned idx = map.getDimPosition(perm(enc, d));
929931
if (!codegen.loops[idx])
@@ -933,7 +935,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
933935
}
934936
// All exhausted at this level (atLevel denotes exactly at this level).
935937
OpOperand *lhs = op.getOutputOperand(0);
936-
if (lhs == tensor) {
938+
if (lhs == t) {
937939
codegen.redExp = hoist ? exp : -1u;
938940
} else if (atLevel) {
939941
merger.exp(exp).val =
@@ -1413,8 +1415,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
14131415
// Detects sparse annotations and translate the per-dimension sparsity
14141416
// information for all tensors to loop indices in the kernel.
14151417
assert(op.getNumOutputs() == 1);
1416-
assert(llvm::none_of(op.getInputAndOutputOperands(),
1417-
[&](OpOperand *t) { return op.isScalar(t); }));
14181418
unsigned numTensors = op.getNumInputsAndOutputs();
14191419
unsigned numLoops = op.iterator_types().getValue().size();
14201420
Merger merger(numTensors, numLoops);
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
2+
// RUN: mlir-opt %s -sparsification | FileCheck %s
3+
4+
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
5+
6+
// A contrived example that demonstrates the many different ways
7+
// in which scalar values can be involved in a sparse kernel
8+
// through the linalg generic op.
9+
10+
#trait = {
11+
indexing_maps = [
12+
affine_map<(i,j) -> (i,j)>, // A (sparse tensor)
13+
affine_map<(i,j) -> ()>, // p (scalar tensor)
14+
affine_map<(i,j) -> ()>, // q (true scalar)
15+
affine_map<(i,j) -> (i,j)> // X (dense tensor out)
16+
],
17+
iterator_types = ["parallel", "parallel"],
18+
doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
19+
}
20+
21+
// CHECK-LABEL: func @mul(
22+
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
23+
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<f32>,
24+
// CHECK-SAME: %[[VAL_2:.*2]]: f32,
25+
// CHECK-SAME: %[[VAL_3:.*3]]: f32,
26+
// CHECK-SAME: %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
27+
// CHECK: %[[VAL_5:.*]] = constant 2.200000e+00 : f32
28+
// CHECK: %[[VAL_6:.*]] = constant 0 : index
29+
// CHECK: %[[VAL_7:.*]] = constant 1 : index
30+
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_2]], %[[VAL_3]] : f32
31+
// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
32+
// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
33+
// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
34+
// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
35+
// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
36+
// CHECK: %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
37+
// CHECK: %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_4]] : memref<32x16xf32>
38+
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
39+
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
40+
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
41+
// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
42+
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
43+
// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
44+
// CHECK: %[[VAL_22:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index
45+
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
46+
// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
47+
// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
48+
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
49+
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_26]], %[[VAL_16]] : f32
50+
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_27]], %[[VAL_2]] : f32
51+
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_28]], %[[VAL_3]] : f32
52+
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_29]], %[[VAL_8]] : f32
53+
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_30]], %[[VAL_5]] : f32
54+
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
55+
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
56+
// CHECK: memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
57+
// CHECK: }
58+
// CHECK: }
59+
// CHECK: %[[VAL_34:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
60+
// CHECK: return %[[VAL_34]] : tensor<32x16xf32>
61+
// CHECK: }
62+
func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
63+
%argp: tensor<f32>,
64+
%argq: f32,
65+
%argr: f32,
66+
%argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
67+
%s = addf %argq, %argr : f32
68+
%c = constant 2.2 : f32
69+
%0 = linalg.generic #trait
70+
ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
71+
outs(%argx: tensor<32x16xf32>) {
72+
^bb(%a: f32, %p: f32, %q: f32, %x: f32):
73+
%0 = mulf %a, %p : f32 // scalar tensor argument
74+
%1 = mulf %0, %q : f32 // scalar argument
75+
%2 = mulf %1, %argr : f32 // scalar argument from outside block
76+
%3 = mulf %2, %s : f32 // scalar value from outside block
77+
%4 = mulf %3, %c : f32 // direct constant from outside block
78+
%5 = addf %4, %x : f32
79+
linalg.yield %5 : f32
80+
} -> tensor<32x16xf32>
81+
82+
return %0 : tensor<32x16xf32>
83+
}

0 commit comments

Comments
 (0)