[mlir][sparse] support 'batch' dimensions in sparse_tensor.print #91411

Merged: 1 commit, merged on May 8, 2024
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 /// Generates a subview into the sizes.
 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                             Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  auto memTp = llvm::cast<MemRefType>(mem.getType());
+  // For higher-dimensional memrefs, we assume that the innermost
+  // dimension is always of the right size.
+  // TODO: generate complex truncating view here too?
+  if (memTp.getRank() > 1)
+    return mem;
+  // Truncate linear memrefs to given size.
   return builder
       .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
+          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
          ArrayRef<int64_t>{0},                    // static offset
          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
          ArrayRef<int64_t>{1})                    // static stride
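
For illustration, a rough sketch of the IR this helper emits (the buffer %mem and size %sz below are hypothetical placeholders, not names from the patch): a linear buffer is truncated with a zero-offset, unit-stride subview, while a higher-rank buffer is now forwarded unchanged.

// Hypothetical 1-D "push_back" buffer whose capacity may exceed the
// number of valid entries %sz; the helper truncates it to size:
%view = memref.subview %mem[0] [%sz] [1] : memref<?xf64> to memref<?xf64>
// A rank-2 buffer (e.g. memref<?x?xf64> with a leading batch dimension)
// is returned as-is, assuming the innermost dimension already has the
// right size.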
@@ -785,45 +785,61 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
   }
 
 private:
-  // Helper to print contents of a single memref. Note that for the "push_back"
-  // vectors, this prints the full capacity, not just the size. This is done
-  // on purpose, so that clients see how much storage has been allocated in
-  // total. Contents of the extra capacity in the buffer may be uninitialized
-  // (unless the flag enable-buffer-initialization is set to true).
+  // Helper to print contents of a single memref. For "push_back" vectors,
+  // we assume that the previous getters for pos/crd/val have added a
+  // slice-to-size view to make sure we just print the size and not the
+  // full capacity.
   //
-  // Generates code to print:
+  // Generates code to print (1-dim or higher):
   //    ( a0, a1, ... )
   static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
+    auto shape = cast<ShapedType>(vec.getType()).getShape();
+    SmallVector<Value> idxs;
+    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+
+  // Helper to the helper.
+  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
+                                 SmallVectorImpl<Value> &idxs) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-    // For loop over elements.
+    // Generate for loop.
     auto zero = constantIndex(rewriter, loc, 0);
-    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto index = constantIndex(rewriter, loc, i);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
     auto step = constantIndex(rewriter, loc, 1);
     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    idxs.push_back(forOp.getInductionVar());
     rewriter.setInsertionPointToStart(forOp.getBody());
-    auto idx = forOp.getInductionVar();
-    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    if (llvm::isa<ComplexType>(val.getType())) {
-      // Since the vector dialect does not support complex types in any op,
-      // we split those into (real, imag) pairs here.
-      Value real = rewriter.create<complex::ReOp>(loc, val);
-      Value imag = rewriter.create<complex::ImOp>(loc, val);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-      rewriter.create<vector::PrintOp>(loc, real,
-                                       vector::PrintPunctuation::Comma);
-      rewriter.create<vector::PrintOp>(loc, imag,
-                                       vector::PrintPunctuation::Close);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    if (i < shape.size() - 1) {
+      // Enter deeper loop nest.
+      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
-      rewriter.create<vector::PrintOp>(loc, val,
-                                       vector::PrintPunctuation::Comma);
+      // Actual contents printing.
+      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+      if (llvm::isa<ComplexType>(val.getType())) {
+        // Since the vector dialect does not support complex types in any op,
+        // we split those into (real, imag) pairs here.
+        Value real = rewriter.create<complex::ReOp>(loc, val);
+        Value imag = rewriter.create<complex::ImOp>(loc, val);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+        rewriter.create<vector::PrintOp>(loc, real,
+                                         vector::PrintPunctuation::Comma);
+        rewriter.create<vector::PrintOp>(loc, imag,
+                                         vector::PrintPunctuation::Close);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+      } else {
+        rewriter.create<vector::PrintOp>(loc, val,
+                                         vector::PrintPunctuation::Comma);
+      }
     }
+    idxs.pop_back();
     rewriter.setInsertionPointAfter(forOp);
-    // Close bracket and end of line.
+    // Close bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
-    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
   }
 
   // Helper method to print run-time lvl/dim sizes.
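
To make the recursion concrete, here is a rough sketch of the code generated for a hypothetical 2-D buffer %vec : memref<2x5xindex> (names invented; in the real pattern the loop bounds come from memref.dim): one loop per dimension, each level wrapped in its own parentheses, with the newline emitted once at the top level.

vector.print punctuation <open>
scf.for %i = %c0 to %c2 step %c1 {
  vector.print punctuation <open>
  scf.for %j = %c0 to %c5 step %c1 {
    %v = memref.load %vec[%i, %j] : memref<2x5xindex>
    vector.print %v : index punctuation <comma>
  }
  vector.print punctuation <close>
}
vector.print punctuation <close>
vector.print punctuation <newline>

This nesting is what produces the grouped output in the new test below, e.g. ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) ) for a batched positions array.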
@@ -29,7 +29,7 @@
   crdWidth = 32
 }>
 
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
   posWidth = 64,
   crdWidth = 32
@@ -42,7 +42,7 @@
 }>
 
 //
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
 //
 module {
   //
@@ -77,7 +77,7 @@ module {
         tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
-    // Setup BatchedCSR.
+    // Setup DenseCSR.
    //
 
    %data1 = arith.constant dense<
@@ -88,7 +88,7 @@
     %crd1 = arith.constant dense<
       [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
 
    //
    // Setup CSRDense.
@@ -137,7 +137,7 @@ module {
     // CHECK-NEXT: ----
     //
     sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
-    sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
    sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
 
    // TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@
     // sparse_tensor.assemble copies buffers when running with the runtime
     // library. Deallocations are not needed when running in codegen mode.
     bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
-    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
    bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
  }
 
@@ -0,0 +1,74 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// TODO: make this work with libgen

// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
// RUN: %{compile} | %{run} | FileCheck %s
//

#BatchedCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
}>

module {

//
// Main driver that tests 3-D sparse tensor printing.
//
func.func @main() {

%pos = arith.constant dense<
[[ 0, 8, 16, 24, 32],
[ 0, 8, 16, 24, 32]]
> : tensor<2x5xindex>

%crd = arith.constant dense<
[[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
> : tensor<2x32xindex>

%val = arith.constant dense<
[[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
[33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
> : tensor<2x32xf64>

%X = sparse_tensor.assemble (%pos, %crd), %val
: (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>

// CHECK: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 32
// CHECK-NEXT: dim = ( 2, 4, 8 )
// CHECK-NEXT: lvl = ( 2, 4, 8 )
// CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) )
// CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, )
// CHECK-SAME: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) )
// CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, )
// CHECK-SAME: ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, ) )
// CHECK-NEXT: ----
sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>

return
}
}
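
To decode the batched storage above by hand: within batch b, row i of the CSR structure owns entries [pos[b][i], pos[b][i+1]) of crd/val. Every row here spans all 8 columns, so batch 0 of %X is equivalent to the dense constant below (a hand-derived illustration, not part of the test):

%batch0 = arith.constant dense<
    [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
     [ 9., 10., 11., 12., 13., 14., 15., 16.],
     [17., 18., 19., 20., 21., 22., 23., 24.],
     [25., 26., 27., 28., 29., 30., 31., 32.]]> : tensor<4x8xf64>

Batch 1 holds values 33..64 in the same layout, which matches the two value groups in the CHECK lines.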