Skip to content

Commit

Permalink
[mlir][sparse] ensure [dis]assembler wrapper methods properly inline (l…
Browse files Browse the repository at this point in the history
  • Loading branch information
aartbik authored Feb 15, 2024
1 parent 3e004d1 commit 4d273b9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
}

// Convert input and output values to [dis[assemble ops for sparse tensors.
// Convert input and output values to [dis]assemble ops for sparse tensors.
void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
Expand Down Expand Up @@ -161,8 +161,6 @@ namespace {
//
// TODO: refine output sparse tensors to work well with external framework
//
// TODO: use "inlining" instead of a wrapper?
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
ValueRange(), inputs, 0, /*isIn=*/true);

// Call original, now internal method.
// Call the original, now private method. A subsequent inlining pass can
// determine whether cloning the method body in place is worthwhile.
auto org = SymbolRefAttr::get(context, wrapper);
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
inputs);
Expand Down
10 changes: 9 additions & 1 deletion mlir/test/Dialect/SparseTensor/torch_linalg.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
// RUN: mlir-opt %s --sparse-assembler \
// RUN: --inline | FileCheck %s --check-prefix=CHECK-INL
// RUN: mlir-opt %s --sparse-assembler \
// RUN: --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
Expand All @@ -20,7 +22,13 @@
// CHECK-HI: func.func private @_internal_main
// CHECK-HI: linalg.matmul
// CHECK-HI: return
//

// CHECK-INL-LABEL: func.func @main
// CHECK-INL: sparse_tensor.assemble
// CHECK-INL: linalg.matmul
// CHECK-INL: return
// CHECK-INL-NOT: func.func private @_internal_main

// CHECK-MID-LABEL: func.func @main
// CHECK-MID: memref.load
// CHECK-MID: call @_internal_main
Expand Down

0 comments on commit 4d273b9

Please sign in to comment.