[mlir][sparse] force a properly sized view on pos/crd/val under codegen #91288

Merged: 3 commits into llvm:main on May 7, 2024

Conversation

aartbik (Contributor) commented on May 7, 2024

Codegen "vectors" for pos/crd/val use the capacity as memref size, not the actual used size. Although the sparsifier itself always uses just the defined pos/crd/val parts, printing these and passing them back to a runtime environment could benefit from wrapping the basic pos/crd/val getters into a proper memref view that sets the right size.

Codegen "vectors" for pos/crd/val use the capacity as memref size, not
the actual used size. Although the sparsifier itself always uses just
the defined pos/crd/val parts, printing these and passing them back
to a runtime environment could benefit from wrapping the basic
pos/crd/val getters into a proper memref view that sets the right size.
llvmbot (Member) commented on May 7, 2024

@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Aart Bik (aartbik)

Patch is 53.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91288.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+30-15)
  • (modified) mlir/test/Dialect/SparseTensor/binary_valued.mlir (+37-32)
  • (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+16-7)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+71-99)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir (+10-10)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_empty.mlir (+11-11)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir (+51-51)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir (+3-3)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir (+6-6)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir (+6-6)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5679f277e14866..339f1d31adabcc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1050,10 +1050,14 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
   matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested position access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
+    // The view is restricted to the actual size to ensure clients
+    // of this operation truly observe size, not capacity!
+    Location loc = op.getLoc();
+    Level lvl = op.getLevel();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
+    auto mem = desc.getPosMemRef(lvl);
+    auto size = desc.getPosMemSize(rewriter, loc, lvl);
+    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
     return success();
   }
 };
@@ -1068,12 +1072,17 @@ class SparseToCoordinatesConverter
   matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
+    // The view is restricted to the actual size to ensure clients
+    // of this operation truly observe size, not capacity!
+    Location loc = op.getLoc();
+    Level lvl = op.getLevel();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(
-        op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
-
+    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
+    if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
+      auto size = desc.getCrdMemSize(rewriter, loc, lvl);
+      mem = genSliceToSize(rewriter, loc, mem, size);
+    }
+    rewriter.replaceOp(op, mem);
     return success();
   }
 };
@@ -1088,11 +1097,14 @@ class SparseToCoordinatesBufferConverter
   matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
+    // The view is restricted to the actual size to ensure clients
+    // of this operation truly observe size, not capacity!
+    Location loc = op.getLoc();
+    Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(op, desc.getAOSMemRef());
-
+    auto mem = desc.getAOSMemRef();
+    auto size = desc.getCrdMemSize(rewriter, loc, lvl);
+    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
     return success();
   }
 };
@@ -1106,10 +1118,13 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested values access with corresponding field.
-    // The cast_op is inserted by type converter to intermix 1:N type
-    // conversion.
+    // The view is restricted to the actual size to ensure clients
+    // of this operation truly observe size, not capacity!
+    Location loc = op.getLoc();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(op, desc.getValMemRef());
+    auto mem = desc.getValMemRef();
+    auto size = desc.getValMemSize(rewriter, loc);
+    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/binary_valued.mlir b/mlir/test/Dialect/SparseTensor/binary_valued.mlir
index e2d410b126a775..dd9b60a6488b6f 100755
--- a/mlir/test/Dialect/SparseTensor/binary_valued.mlir
+++ b/mlir/test/Dialect/SparseTensor/binary_valued.mlir
@@ -26,12 +26,11 @@
 //
 // Make sure X += A * A => X += 1 in single loop.
 //
-//
 // CHECK-LABEL:   func.func @sum_squares(
 // CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
 // CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf32>,
-// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
+// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
@@ -40,23 +39,25 @@
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
 // CHECK:           linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
-// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK:             %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
-// CHECK:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
-// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:               %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
-// CHECK:                 %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_4]] : f32
-// CHECK:                 scf.yield %[[VAL_26]] : f32
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK:           %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
+// CHECK:             %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:               %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (f32) {
+// CHECK:                 %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_4]] : f32
+// CHECK:                 scf.yield %[[VAL_28]] : f32
 // CHECK:               } {"Emitted from" = "linalg.generic"}
-// CHECK:               scf.yield %[[VAL_23]] : f32
+// CHECK:               scf.yield %[[VAL_25]] : f32
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             scf.yield %[[VAL_16]] : f32
+// CHECK:             scf.yield %[[VAL_18]] : f32
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK:           memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
 // CHECK:           return %[[VAL_10]] : memref<f32>
 // CHECK:         }
 //
@@ -99,25 +100,29 @@ func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
 // CHECK:           linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
-// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
-// CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK:             %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
-// CHECK:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
-// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:               %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
-// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_24]]] : memref<?xindex>
-// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_13]], %[[VAL_17]], %[[VAL_26]]] : memref<2x3x8xf32>
-// CHECK:                 %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_25]] : f32
-// CHECK:                 scf.yield %[[VAL_28]] : f32
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK:           %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
+// CHECK:           %[[VAL_14:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_13]]] [1] : memref<?xindex> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
+// CHECK:           %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK:             %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_18]]) -> (f32) {
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_19]] : index
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:               %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_22]]) -> (f32) {
+// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_21]], %[[VAL_30]]] : memref<2x3x8xf32>
+// CHECK:                 %[[VAL_32:.*]] = arith.addf %[[VAL_31]], %[[VAL_29]] : f32
+// CHECK:                 scf.yield %[[VAL_32]] : f32
 // CHECK:               } {"Emitted from" = "linalg.generic"}
-// CHECK:               scf.yield %[[VAL_23]] : f32
+// CHECK:               scf.yield %[[VAL_27]] : f32
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             scf.yield %[[VAL_16]] : f32
+// CHECK:             scf.yield %[[VAL_20]] : f32
 // CHECK:           } {"Emitted from" = "linalg.generic"}
-// CHECK:           memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
+// CHECK:           memref.store %[[VAL_16]], %[[VAL_10]][] : memref<f32>
 // CHECK:           return %[[VAL_10]] : memref<f32>
 // CHECK:         }
 //
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 40bfa1e4e2a501..af78458f109329 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -266,7 +266,9 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-//       CHECK: return %[[A2]] : memref<?xi32>
+//       CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] pos_mem_sz at 1
+//       CHECK: %[[V:.*]] = memref.subview %[[A2]][0] [%[[S]]] [1]
+//       CHECK: return %[[V]] : memref<?xi32>
 func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
   %0 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
   return %0 : memref<?xi32>
@@ -279,7 +281,9 @@ func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32>
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-//       CHECK: return %[[A3]] : memref<?xi64>
+//       CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
+//       CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
+//       CHECK: return %[[V]] : memref<?xi64>
 func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
   %0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
   return %0 : memref<?xi64>
@@ -292,7 +296,9 @@ func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-//       CHECK: return %[[A4]] : memref<?xf64>
+//       CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
+//       CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
+//       CHECK: return %[[V]] : memref<?xf64>
 func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
@@ -305,13 +311,14 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-//       CHECK: return %[[A4]] : memref<?xf64>
+//       CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
+//       CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
+//       CHECK: return %[[V]] : memref<?xf64>
 func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xf64>
   return %0 : memref<?xf64>
 }
 
-
 // CHECK-LABEL: func.func @sparse_indices_coo(
 //  CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
@@ -320,7 +327,7 @@ func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
 //       CHECK: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]]  crd_mem_sz at 1
+//       CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
 //       CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index
 //       CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref<?xindex> to memref<?xindex, strided<[2]>>
 //       CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref<?xindex, strided<[2]>> to memref<?xindex, strided<[?], offset: ?>>
@@ -337,7 +344,9 @@ func.func @sparse_indices_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
 //  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
 //  CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
-//       CHECK: return %[[A3]] : memref<?xindex>
+//       CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
+//       CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
+//       CHECK: return %[[V]] : memref<?xindex>
 func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex> {
   %0 = sparse_tensor.coordinates_buffer  %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
   return %0 : memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 5145d6c1dcfc32..ad12b637d0c522 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -1,5 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN:  --sparse-reinterpret-map --sparsification --sparse-tensor-codegen \
 // RUN:  --canonicalize --cse | FileCheck %s
@@ -11,45 +9,6 @@
 //
 // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
 //
-// CHECK-LABEL:   func.func private @_insert_dense_compressed_4_4_f64_0_0(
-// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
-// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
-// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf64>,
-// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
-// CHECK-SAME:      %[[VAL_4:.*4]]: index,
-// CHECK-SAME:      %[[VAL_5:.*5]]: index,
-// CHECK-SAME:      %[[VAL_6:.*6]]: f64) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_7:.*]] = arith.constant false
-// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  crd_mem_sz at 1 : !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
-// CHECK:           %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
-// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK:             %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index
-// CHECK:             scf.yield %[[VAL_18]] : i1
-// CHECK:           } else {
-// CHECK:             memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:             scf.yield %[[VAL_7]] : i1
-// CHECK:           }
-// CHECK:           %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref<?xindex>, !sparse_tensor.storage_specifier
-// CHECK:             scf.yield %[[VAL_1]], %[[VAL_3]] : memref<?xindex>, !sparse_tensor.storage_specifier
-// CHECK:           } else {
-// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
-// CHECK:             memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
-// CHECK:             %[[VAL_22:.*]], %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
-// CHECK:             %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]]  crd_mem_sz at 1 with %[[VAL_24]] : !sparse_tensor.storage_specifier
-// CHECK:             scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_t...
[truncated]
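To see the effect from a client's perspective, here is a hedged usage sketch (hypothetical function name and encoding, not part of the patch): after this change, `sparse_tensor.values` lowers to a subview sized by `val_mem_sz`, so querying the memref dimension reports the number of stored entries rather than the underlying buffer capacity.

```mlir
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

func.func @used_size(%arg0: tensor<?x?xf64, #CSR>) -> index {
  %c0 = arith.constant 0 : index
  // Lowered by --sparse-tensor-codegen to a properly sized memref view.
  %vals = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
  // Reports the used size (number of stored values), not the capacity.
  %n = memref.dim %vals, %c0 : memref<?xf64>
  return %n : index
}
```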

aartbik merged commit 5c51165 into llvm:main on May 7, 2024 (4 checks passed).
aartbik deleted the bik branch on May 7, 2024 at 16:20.
Labels: mlir:gpu, mlir:sparse (Sparse compiler in MLIR), mlir