|
| 1 | +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py |
| 2 | +// RUN: mlir-opt %s -sparsification | FileCheck %s |
| 3 | + |
| 4 | +#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> |
| 5 | + |
| 6 | +// A contrived example that demonstrates the many different ways |
| 7 | +// in which scalar values can be involved in a sparse kernel |
| 8 | +// through the linalg generic op. |
| 9 | + |
| 10 | +#trait = { |
| 11 | + indexing_maps = [ |
| 12 | + affine_map<(i,j) -> (i,j)>, // A (sparse tensor) |
| 13 | + affine_map<(i,j) -> ()>, // p (scalar tensor) |
| 14 | + affine_map<(i,j) -> ()>, // q (true scalar) |
| 15 | + affine_map<(i,j) -> (i,j)> // X (dense tensor out) |
| 16 | + ], |
| 17 | + iterator_types = ["parallel", "parallel"], |
| 18 | + doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2" |
| 19 | +} |
| 20 | + |
| 21 | +// CHECK-LABEL: func @mul( |
| 22 | +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>, |
| 23 | +// CHECK-SAME: %[[VAL_1:.*1]]: tensor<f32>, |
| 24 | +// CHECK-SAME: %[[VAL_2:.*2]]: f32, |
| 25 | +// CHECK-SAME: %[[VAL_3:.*3]]: f32, |
| 26 | +// CHECK-SAME: %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> { |
| 27 | +// CHECK: %[[VAL_5:.*]] = constant 2.200000e+00 : f32 |
| 28 | +// CHECK: %[[VAL_6:.*]] = constant 0 : index |
| 29 | +// CHECK: %[[VAL_7:.*]] = constant 1 : index |
| 30 | +// CHECK: %[[VAL_8:.*]] = addf %[[VAL_2]], %[[VAL_3]] : f32 |
| 31 | +// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex> |
| 32 | +// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex> |
| 33 | +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex> |
| 34 | +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex> |
| 35 | +// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32> |
| 36 | +// CHECK: %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32> |
| 37 | +// CHECK: %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_4]] : memref<32x16xf32> |
| 38 | +// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32> |
| 39 | +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex> |
| 40 | +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex> |
| 41 | +// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] { |
| 42 | +// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex> |
| 43 | +// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex> |
| 44 | +// CHECK: %[[VAL_22:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index |
| 45 | +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex> |
| 46 | +// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] { |
| 47 | +// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex> |
| 48 | +// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32> |
| 49 | +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_26]], %[[VAL_16]] : f32 |
| 50 | +// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_27]], %[[VAL_2]] : f32 |
| 51 | +// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_28]], %[[VAL_3]] : f32 |
| 52 | +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_29]], %[[VAL_8]] : f32 |
| 53 | +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_30]], %[[VAL_5]] : f32 |
| 54 | +// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32> |
| 55 | +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32 |
| 56 | +// CHECK: memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32> |
| 57 | +// CHECK: } |
| 58 | +// CHECK: } |
| 59 | +// CHECK: %[[VAL_34:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32> |
| 60 | +// CHECK: return %[[VAL_34]] : tensor<32x16xf32> |
| 61 | +// CHECK: } |
| 62 | +func @mul(%arga: tensor<32x16xf32, #SparseMatrix>, |
| 63 | + %argp: tensor<f32>, |
| 64 | + %argq: f32, |
| 65 | + %argr: f32, |
| 66 | + %argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> { |
| 67 | + %s = addf %argq, %argr : f32 |
| 68 | + %c = constant 2.2 : f32 |
| 69 | + %0 = linalg.generic #trait |
| 70 | + ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32) |
| 71 | + outs(%argx: tensor<32x16xf32>) { |
| 72 | + ^bb(%a: f32, %p: f32, %q: f32, %x: f32): |
| 73 | + %0 = mulf %a, %p : f32 // scalar tensor argument |
| 74 | + %1 = mulf %0, %q : f32 // scalar argument |
| 75 | + %2 = mulf %1, %argr : f32 // scalar argument from outside block |
| 76 | + %3 = mulf %2, %s : f32 // scalar value from outside block |
| 77 | + %4 = mulf %3, %c : f32 // direct constant from outside block |
| 78 | + %5 = addf %4, %x : f32 |
| 79 | + linalg.yield %5 : f32 |
| 80 | + } -> tensor<32x16xf32> |
| 81 | + |
| 82 | + return %0 : tensor<32x16xf32> |
| 83 | +} |
0 commit comments