
Commit c4e5a8a

[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)
1 parent 584253c commit c4e5a8a

File tree

4 files changed: +130, -34 lines changed


mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 9 additions & 3 deletions
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 /// Generates a subview into the sizes.
 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                             Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  auto memTp = llvm::cast<MemRefType>(mem.getType());
+  // For higher-dimensional memrefs, we assume that the innermost
+  // dimension is always of the right size.
+  // TODO: generate complex truncating view here too?
+  if (memTp.getRank() > 1)
+    return mem;
+  // Truncate linear memrefs to given size.
   return builder
       .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
+          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
           ArrayRef<int64_t>{0},                    // static offset
           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
           ArrayRef<int64_t>{1})                    // static stride
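For context, a minimal sketch of the truncating view that the 1-D path above still emits; the value names %mem and %sz and the f64 element type are stand-ins for illustration, not taken from the patch. Rank > 1 memrefs now bypass this subview entirely.

  // Truncate a 1-D capacity buffer %mem to its in-use size %sz.
  // Static offset 0 and stride 1 preserve the identity layout.
  %view = memref.subview %mem[0] [%sz] [1] : memref<?xf64> to memref<?xf64>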

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 41 additions & 25 deletions
@@ -785,45 +785,61 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   }
 
 private:
-  // Helper to print contents of a single memref. Note that for the "push_back"
-  // vectors, this prints the full capacity, not just the size. This is done
-  // on purpose, so that clients see how much storage has been allocated in
-  // total. Contents of the extra capacity in the buffer may be uninitialized
-  // (unless the flag enable-buffer-initialization is set to true).
+  // Helper to print contents of a single memref. For "push_back" vectors,
+  // we assume that the previous getters for pos/crd/val have added a
+  // slice-to-size view to make sure we just print the size and not the
+  // full capacity.
   //
-  // Generates code to print:
+  // Generates code to print (1-dim or higher):
   //    ( a0, a1, ... )
   static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
+    auto shape = cast<ShapedType>(vec.getType()).getShape();
+    SmallVector<Value> idxs;
+    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+
+  // Helper to the helper.
+  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
+                                 SmallVectorImpl<Value> &idxs) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-    // For loop over elements.
+    // Generate for loop.
     auto zero = constantIndex(rewriter, loc, 0);
-    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto index = constantIndex(rewriter, loc, i);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
     auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    idxs.push_back(forOp.getInductionVar());
     rewriter.setInsertionPointToStart(forOp.getBody());
-    auto idx = forOp.getInductionVar();
-    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    if (llvm::isa<ComplexType>(val.getType())) {
-      // Since the vector dialect does not support complex types in any op,
-      // we split those into (real, imag) pairs here.
-      Value real = rewriter.create<complex::ReOp>(loc, val);
-      Value imag = rewriter.create<complex::ImOp>(loc, val);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-      rewriter.create<vector::PrintOp>(loc, real,
-                                       vector::PrintPunctuation::Comma);
-      rewriter.create<vector::PrintOp>(loc, imag,
-                                       vector::PrintPunctuation::Close);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    if (i < shape.size() - 1) {
+      // Enter deeper loop nest.
+      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
     } else {
-      rewriter.create<vector::PrintOp>(loc, val,
-                                       vector::PrintPunctuation::Comma);
+      // Actual contents printing.
+      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+      if (llvm::isa<ComplexType>(val.getType())) {
+        // Since the vector dialect does not support complex types in any op,
+        // we split those into (real, imag) pairs here.
+        Value real = rewriter.create<complex::ReOp>(loc, val);
+        Value imag = rewriter.create<complex::ImOp>(loc, val);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+        rewriter.create<vector::PrintOp>(loc, real,
+                                         vector::PrintPunctuation::Comma);
+        rewriter.create<vector::PrintOp>(loc, imag,
+                                         vector::PrintPunctuation::Close);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+      } else {
+        rewriter.create<vector::PrintOp>(loc, val,
+                                         vector::PrintPunctuation::Comma);
+      }
     }
+    idxs.pop_back();
     rewriter.setInsertionPointAfter(forOp);
-    // Close bracket and end of line.
+    // Close bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
-    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
   }
 
   // Helper method to print run-time lvl/dim sizes.
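To illustrate the recursion, the IR it builds for a 2-D (batched) buffer roughly takes the shape below. This is a hedged sketch with stand-in names (%vals and its memref<?x?xf64> type are assumptions, and constants are hoisted for brevity), not output reproduced from the compiler:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  vector.print punctuation <open>
  %d0 = memref.dim %vals, %c0 : memref<?x?xf64>
  scf.for %i = %c0 to %d0 step %c1 {
    // One nested parenthesized list per batch.
    vector.print punctuation <open>
    %d1 = memref.dim %vals, %c1 : memref<?x?xf64>
    scf.for %j = %c0 to %d1 step %c1 {
      // Innermost level: actual contents printing.
      %v = memref.load %vals[%i, %j] : memref<?x?xf64>
      vector.print %v : f64 punctuation <comma>
    }
    vector.print punctuation <close>
  }
  vector.print punctuation <close>
  vector.print punctuation <newline>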

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir

Lines changed: 6 additions & 6 deletions
@@ -29,7 +29,7 @@
   crdWidth = 32
 }>
 
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
   posWidth = 64,
   crdWidth = 32
@@ -42,7 +42,7 @@
 }>
 
 //
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
 //
 module {
   //
@@ -77,7 +77,7 @@ module {
                tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
-    // Setup BatchedCSR.
+    // Setup DenseCSR.
     //
 
     %data1 = arith.constant dense<
@@ -88,7 +88,7 @@ module {
     %crd1 = arith.constant dense<
       [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
 
     //
     // Setup CSRDense.
@@ -137,7 +137,7 @@ module {
     // CHECK-NEXT: ----
     //
     sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
-    sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
     sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
 
     // TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@ module {
     // sparse_tensor.assemble copies buffers when running with the runtime
     // library. Deallocations are not needed when running in codegen mode.
     bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
-    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
     bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
   }
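The rename is about accuracy, not behavior: this encoding keeps its outer level dense, so calling it #BatchedCSR became misleading once a true batch level type exists. Side by side, the first map is the one from this test and the second is the one exercised by the new test below:

  #DenseCSR   = #sparse_tensor.encoding<{
    map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
  }>
  #BatchedCSR = #sparse_tensor.encoding<{
    map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
  }>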

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// TODO: make this work with libgen
+
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#BatchedCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
+}>
+
+module {
+
+  //
+  // Main driver that tests 3-D sparse tensor printing.
+  //
+  func.func @main() {
+
+    %pos = arith.constant dense<
+       [[ 0, 8, 16, 24, 32],
+        [ 0, 8, 16, 24, 32]]
+    > : tensor<2x5xindex>
+
+    %crd = arith.constant dense<
+       [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
+        [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
+    > : tensor<2x32xindex>
+
+    %val = arith.constant dense<
+      [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.,
+        12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
+        23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
+       [33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
+        44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
+        55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
+    > : tensor<2x32xf64>
+
+    %X = sparse_tensor.assemble (%pos, %crd), %val
+       : (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>
+
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 32
+    // CHECK-NEXT: dim = ( 2, 4, 8 )
+    // CHECK-NEXT: lvl = ( 2, 4, 8 )
+    // CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) )
+    // CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, )
+    // CHECK-SAME: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) )
+    // CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, )
+    // CHECK-SAME: ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, ) )
+    // CHECK-NEXT: ----
+    sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>
+
+    return
+  }
+}
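Note how the batch level shows up directly in the assembled storage: every buffer is 2-D, carrying one classic CSR array per batch, which the recursive printer renders as one nested parenthesized list per batch in the CHECK lines above. The shapes, taken from the assemble operands in this test:

  %pos : tensor<2x5xindex>   // one CSR positions array per batch
  %crd : tensor<2x32xindex>  // one CSR coordinates array per batch
  %val : tensor<2x32xf64>    // one CSR values array per batch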
