[mlir][sparse] force a properly sized view on pos/crd/val under codegen #91288
Conversation
Codegen "vectors" for pos/crd/val use the capacity as memref size, not the actual used size. Although the sparsifier itself always uses just the defined pos/crd/val parts, printing these and passing them back to a runtime environment could benefit from wrapping the basic pos/crd/val getters into a proper memref view that sets the right size.
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-gpu

Author: Aart Bik (aartbik)

Changes: Codegen "vectors" for pos/crd/val use the capacity as the memref size, not the actual used size. Although the sparsifier itself only ever uses the defined parts of pos/crd/val, printing these buffers or passing them back to a runtime environment benefits from wrapping the basic pos/crd/val getters in a proper memref view that reflects the actual size.

Patch is 53.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91288.diff

10 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5679f277e14866..339f1d31adabcc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1050,10 +1050,14 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested position access with corresponding field.
- // The cast_op is inserted by type converter to intermix 1:N type
- // conversion.
+ // The view is restricted to the actual size to ensure clients
+ // of this operation truly observe size, not capacity!
+ Location loc = op.getLoc();
+ Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
+ auto mem = desc.getPosMemRef(lvl);
+ auto size = desc.getPosMemSize(rewriter, loc, lvl);
+ rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1068,12 +1072,17 @@ class SparseToCoordinatesConverter
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
- // The cast_op is inserted by type converter to intermix 1:N type
- // conversion.
+ // The view is restricted to the actual size to ensure clients
+ // of this operation truly observe size, not capacity!
+ Location loc = op.getLoc();
+ Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- rewriter.replaceOp(
- op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
-
+ auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
+ if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
+ auto size = desc.getCrdMemSize(rewriter, loc, lvl);
+ mem = genSliceToSize(rewriter, loc, mem, size);
+ }
+ rewriter.replaceOp(op, mem);
return success();
}
};
@@ -1088,11 +1097,14 @@ class SparseToCoordinatesBufferConverter
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
- // The cast_op is inserted by type converter to intermix 1:N type
- // conversion.
+ // The view is restricted to the actual size to ensure clients
+ // of this operation truly observe size, not capacity!
+ Location loc = op.getLoc();
+ Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- rewriter.replaceOp(op, desc.getAOSMemRef());
-
+ auto mem = desc.getAOSMemRef();
+ auto size = desc.getCrdMemSize(rewriter, loc, lvl);
+ rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1106,10 +1118,13 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested values access with corresponding field.
- // The cast_op is inserted by type converter to intermix 1:N type
- // conversion.
+ // The view is restricted to the actual size to ensure clients
+ // of this operation truly observe size, not capacity!
+ Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- rewriter.replaceOp(op, desc.getValMemRef());
+ auto mem = desc.getValMemRef();
+ auto size = desc.getValMemSize(rewriter, loc);
+ rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
diff --git a/mlir/test/Dialect/SparseTensor/binary_valued.mlir b/mlir/test/Dialect/SparseTensor/binary_valued.mlir
index e2d410b126a775..dd9b60a6488b6f 100755
--- a/mlir/test/Dialect/SparseTensor/binary_valued.mlir
+++ b/mlir/test/Dialect/SparseTensor/binary_valued.mlir
@@ -26,12 +26,11 @@
//
// Make sure X += A * A => X += 1 in single loop.
//
-//
// CHECK-LABEL: func.func @sum_squares(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
-// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
+// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
@@ -40,23 +39,25 @@
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
-// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
-// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_4]] : f32
-// CHECK: scf.yield %[[VAL_26]] : f32
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_7]] : index
+// CHECK: %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (f32) {
+// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_4]] : f32
+// CHECK: scf.yield %[[VAL_28]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_23]] : f32
+// CHECK: scf.yield %[[VAL_25]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_16]] : f32
+// CHECK: scf.yield %[[VAL_18]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
@@ -99,25 +100,29 @@ func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
-// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_24]]] : memref<?xindex>
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_13]], %[[VAL_17]], %[[VAL_26]]] : memref<2x3x8xf32>
-// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_25]] : f32
-// CHECK: scf.yield %[[VAL_28]] : f32
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_13]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_7]] : index
+// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_18]]) -> (f32) {
+// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_19]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_22]]) -> (f32) {
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_21]], %[[VAL_30]]] : memref<2x3x8xf32>
+// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_31]], %[[VAL_29]] : f32
+// CHECK: scf.yield %[[VAL_32]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_23]] : f32
+// CHECK: scf.yield %[[VAL_27]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: scf.yield %[[VAL_16]] : f32
+// CHECK: scf.yield %[[VAL_20]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK: memref.store %[[VAL_16]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 40bfa1e4e2a501..af78458f109329 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -266,7 +266,9 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-// CHECK: return %[[A2]] : memref<?xi32>
+// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] pos_mem_sz at 1
+// CHECK: %[[V:.*]] = memref.subview %[[A2]][0] [%[[S]]] [1]
+// CHECK: return %[[V]] : memref<?xi32>
func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
%0 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
return %0 : memref<?xi32>
@@ -279,7 +281,9 @@ func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32>
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-// CHECK: return %[[A3]] : memref<?xi64>
+// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
+// CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
+// CHECK: return %[[V]] : memref<?xi64>
func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
%0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
return %0 : memref<?xi64>
@@ -292,7 +296,9 @@ func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-// CHECK: return %[[A4]] : memref<?xf64>
+// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
+// CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
+// CHECK: return %[[V]] : memref<?xf64>
func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
@@ -305,13 +311,14 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-// CHECK: return %[[A4]] : memref<?xf64>
+// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
+// CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
+// CHECK: return %[[V]] : memref<?xf64>
func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xf64>
return %0 : memref<?xf64>
}
-
// CHECK-LABEL: func.func @sparse_indices_coo(
// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
@@ -320,7 +327,7 @@ func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
+// CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
// CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index
// CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref<?xindex> to memref<?xindex, strided<[2]>>
// CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref<?xindex, strided<[2]>> to memref<?xindex, strided<[?], offset: ?>>
@@ -337,7 +344,9 @@ func.func @sparse_indices_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-// CHECK: return %[[A3]] : memref<?xindex>
+// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
+// CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
+// CHECK: return %[[V]] : memref<?xindex>
func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex> {
%0 = sparse_tensor.coordinates_buffer %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
return %0 : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 5145d6c1dcfc32..ad12b637d0c522 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -1,5 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --sparse-reinterpret-map --sparsification --sparse-tensor-codegen \
// RUN: --canonicalize --cse | FileCheck %s
@@ -11,45 +9,6 @@
//
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
-// CHECK-LABEL: func.func private @_insert_dense_compressed_4_4_f64_0_0(
-// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
-// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
-// CHECK-SAME: %[[VAL_4:.*4]]: index,
-// CHECK-SAME: %[[VAL_5:.*5]]: index,
-// CHECK-SAME: %[[VAL_6:.*6]]: f64) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_7:.*]] = arith.constant false
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] crd_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
-// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index
-// CHECK: scf.yield %[[VAL_18]] : i1
-// CHECK: } else {
-// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: scf.yield %[[VAL_7]] : i1
-// CHECK: }
-// CHECK: %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref<?xindex>, !sparse_tensor.storage_specifier
-// CHECK: scf.yield %[[VAL_1]], %[[VAL_3]] : memref<?xindex>, !sparse_tensor.storage_specifier
-// CHECK: } else {
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
-// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]], %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
-// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] crd_mem_sz at 1 with %[[VAL_24]] : !sparse_tensor.storage_specifier
-// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_t...
[truncated]