Skip to content

[mlir][sparse] support 'batch' dimensions in sparse_tensor.print #91411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented May 7, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/91411.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+9-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+41-25)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir (+6-6)
  • (added) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir (+74)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d9b203a886488..164e722c45dba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 /// Generates a subview into the sizes.
 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                             Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  auto memTp = llvm::cast<MemRefType>(mem.getType());
+  // For higher-dimensional memrefs, we assume that the innermost
+  // dimension is always of the right size.
+  // TODO: generate complex truncating view here too?
+  if (memTp.getRank() > 1)
+    return mem;
+  // Truncate linear memrefs to given size.
   return builder
       .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
+          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
           ArrayRef<int64_t>{0},                    // static offset
           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
           ArrayRef<int64_t>{1})                    // static stride
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7d469198a653c..025fd3331ba89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -785,45 +785,61 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   }
 
 private:
-  // Helper to print contents of a single memref. Note that for the "push_back"
-  // vectors, this prints the full capacity, not just the size. This is done
-  // on purpose, so that clients see how much storage has been allocated in
-  // total. Contents of the extra capacity in the buffer may be uninitialized
-  // (unless the flag enable-buffer-initialization is set to true).
+  // Helper to print contents of a single memref. For "push_back" vectors,
+  // we assume that the previous getters for pos/crd/val have added a
+  // slice-to-size view to make sure we just print the size and not the
+  // full capacity.
   //
-  // Generates code to print:
+  // Generates code to print (1-dim or higher):
   //    ( a0, a1, ... )
   static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
+    auto shape = cast<ShapedType>(vec.getType()).getShape();
+    SmallVector<Value> idxs;
+    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+
+  // Helper to the helper.
+  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
+                                 SmallVectorImpl<Value> &idxs) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-    // For loop over elements.
+    // Generate for loop.
     auto zero = constantIndex(rewriter, loc, 0);
-    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto index = constantIndex(rewriter, loc, i);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
     auto step = constantIndex(rewriter, loc, 1);
     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    idxs.push_back(forOp.getInductionVar());
     rewriter.setInsertionPointToStart(forOp.getBody());
-    auto idx = forOp.getInductionVar();
-    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    if (llvm::isa<ComplexType>(val.getType())) {
-      // Since the vector dialect does not support complex types in any op,
-      // we split those into (real, imag) pairs here.
-      Value real = rewriter.create<complex::ReOp>(loc, val);
-      Value imag = rewriter.create<complex::ImOp>(loc, val);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-      rewriter.create<vector::PrintOp>(loc, real,
-                                       vector::PrintPunctuation::Comma);
-      rewriter.create<vector::PrintOp>(loc, imag,
-                                       vector::PrintPunctuation::Close);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    if (i < shape.size() - 1) {
+      // Enter deeper loop nest.
+      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
     } else {
-      rewriter.create<vector::PrintOp>(loc, val,
-                                       vector::PrintPunctuation::Comma);
+      // Actual contents printing.
+      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+      if (llvm::isa<ComplexType>(val.getType())) {
+        // Since the vector dialect does not support complex types in any op,
+        // we split those into (real, imag) pairs here.
+        Value real = rewriter.create<complex::ReOp>(loc, val);
+        Value imag = rewriter.create<complex::ImOp>(loc, val);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+        rewriter.create<vector::PrintOp>(loc, real,
+                                         vector::PrintPunctuation::Comma);
+        rewriter.create<vector::PrintOp>(loc, imag,
+                                         vector::PrintPunctuation::Close);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+      } else {
+        rewriter.create<vector::PrintOp>(loc, val,
+                                         vector::PrintPunctuation::Comma);
+      }
     }
+    idxs.pop_back();
     rewriter.setInsertionPointAfter(forOp);
-    // Close bracket and end of line.
+    // Close bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
-    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
   }
 
   // Helper method to print run-time lvl/dim sizes.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index 20ae7e86285cc..467a77f30777a 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -29,7 +29,7 @@
   crdWidth = 32
 }>
 
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
   posWidth = 64,
   crdWidth = 32
@@ -42,7 +42,7 @@
 }>
 
 //
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
 //
 module {
   //
@@ -77,7 +77,7 @@ module {
         tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
-    // Setup BatchedCSR.
+    // Setup DenseCSR.
     //
 
     %data1 = arith.constant dense<
@@ -88,7 +88,7 @@ module {
     %crd1 = arith.constant dense<
        [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
 
     //
     // Setup CSRDense.
@@ -137,7 +137,7 @@ module {
     // CHECK-NEXT: ----
     //
     sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
-    sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
     sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
 
     // TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@ module {
       // sparse_tensor.assemble copies buffers when running with the runtime
       // library. Deallocations are not needed when running in codegen mode.
       bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
-      bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+      bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
       bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
     }
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
new file mode 100755
index 0000000000000..98dee304fa511
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+//  do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// TODO: make this work with libgen
+
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#BatchedCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
+}>
+
+module {
+
+  //
+  // Main driver that tests 3-D sparse tensor printing.
+  //
+  func.func @main() {
+
+    %pos = arith.constant dense<
+      [[ 0, 8, 16, 24, 32],
+       [ 0, 8, 16, 24, 32]]
+    > : tensor<2x5xindex>
+
+    %crd = arith.constant dense<
+      [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
+       [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
+    > : tensor<2x32xindex>
+
+    %val = arith.constant dense<
+      [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.,
+        12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
+        23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
+       [33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
+        44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
+        55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
+    > : tensor<2x32xf64>
+
+    %X = sparse_tensor.assemble (%pos, %crd), %val
+      : (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>
+
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 32
+    // CHECK-NEXT: dim = ( 2, 4, 8 )
+    // CHECK-NEXT: lvl = ( 2, 4, 8 )
+    // CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32,  )( 0, 8, 16, 24, 32,  ) )
+    // CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  )
+    // CHECK-SAME:            ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,  ) )
+    // CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,  )
+    // CHECK-SAME:            ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,  ) )
+    // CHECK-NEXT: ----
+    sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>
+
+    return
+  }
+}

@aartbik aartbik merged commit c4e5a8a into llvm:main May 8, 2024
6 of 7 checks passed
@aartbik aartbik deleted the bik branch May 8, 2024 02:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants