From 4d273b948ef064230091e41cf81f4c1b91d5beb4 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 15 Feb 2024 11:39:32 -0800 Subject: [PATCH] [mlir][sparse] ensure [dis]assembler wrapper methods properly inline (#81907) --- .../SparseTensor/Transforms/SparseAssembler.cpp | 7 +++---- mlir/test/Dialect/SparseTensor/torch_linalg.mlir | 10 +++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp index 98f9d15d09fa3..9414d81e6bf5c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp @@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl &convTypes, } } -// Convert input and output values to [dis[assemble ops for sparse tensors. +// Convert input and output values to [dis]assemble ops for sparse tensors. void convVals(OpBuilder &builder, Location loc, TypeRange types, ValueRange fromVals, ValueRange extraVals, SmallVectorImpl &toVals, unsigned extra, bool isIn) { @@ -161,8 +161,6 @@ namespace { // // TODO: refine output sparse tensors to work well with external framework // -// TODO: use "inlining" instead of a wrapper? -// struct SparseFuncAssembler : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern { convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), ValueRange(), inputs, 0, /*isIn=*/true); - // Call original, now internal method. + // Call the original, now private method. A subsequent inlining pass can + // determine whether cloning the method body in place is worthwhile. auto org = SymbolRefAttr::get(context, wrapper); auto call = rewriter.create(loc, funcOp.getResultTypes(), org, inputs); diff --git a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir index f29e6b143783a..4bb5938b2e44e 100644 --- a/mlir/test/Dialect/SparseTensor/torch_linalg.mlir +++ b/mlir/test/Dialect/SparseTensor/torch_linalg.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI // RUN: mlir-opt %s --sparse-assembler \ +// RUN: --inline | FileCheck %s --check-prefix=CHECK-INL +// RUN: mlir-opt %s --sparse-assembler \ // RUN: --linalg-generalize-named-ops \ // RUN: --linalg-fuse-elementwise-ops \ // RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID @@ -20,7 +22,13 @@ // CHECK-HI: func.func private @_internal_main // CHECK-HI: linalg.matmul // CHECK-HI: return -// + +// CHECK-INL-LABEL: func.func @main +// CHECK-INL: sparse_tensor.assemble +// CHECK-INL: linalg.matmul +// CHECK-INL: return +// CHECK-INL-NOT: func.func private @_internal_main + // CHECK-MID-LABEL: func.func @main // CHECK-MID: memref.load // CHECK-MID: call @_internal_main