[mlir][tensor] add tensor insert/extract op folders #142458
Conversation
Signed-off-by: Asra Ali <asraa@google.com>
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir
Author: None (asraa)
Changes: Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:
- a folder for `tensor.extract` of a `tensor.insert` at the same indices,
- a canonicalization pattern folding `tensor.insert` of a constant scalar into a constant destination tensor,
- a rewrite pattern (exposed via `populateFoldCollapseExtractPatterns`) folding `tensor.extract` of a `tensor.collapse_shape` into an extract of the source tensor.
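As a hedged illustration of the first item (not taken from the patch or its tests; the function name, shapes, and values are invented), the extract-after-insert folder turns IR like this:

```mlir
func.func @fold_extract_after_insert(%dest: tensor<4xf32>, %scalar: f32, %i: index) -> f32 {
  %inserted = tensor.insert %scalar into %dest[%i] : tensor<4xf32>
  // Extract at exactly the index used by the insert above.
  %e = tensor.extract %inserted[%i] : tensor<4xf32>
  return %e : f32
}
```

into a function that returns %scalar directly, since the insert and extract use the same index and element type.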
Full diff: https://github.com/llvm/llvm-project/pull/142458.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index eb550bb469b9f..e8e1342ef36fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
+/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
+/// source tensor.
+void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..c0885a3763827 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 30ca20fc0d883..f2a7220b4bedc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
+#include <vector>
using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
}
};
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
+/// tensor<12xf64>
+/// %extracted_element = tensor.extract %val[%c10] :
+/// tensor<12xf64>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
+struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+ PatternRewriter &rewriter) const final {
+ auto collapseOp =
+ extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseOp)
+ return failure();
+ if (!collapseOp.getSrcType().hasStaticShape())
+ return failure();
+
+ auto sourceSizes = collapseOp.getSrcType().getShape();
+
+ SmallVector<Value> indices(extractOp.getIndices().begin(),
+ extractOp.getIndices().end());
+ SmallVector<Value> sourceIndices;
+ for (auto [index, group] :
+ llvm::zip(indices, collapseOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ auto groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
+ }
+
+ SmallVector<int64_t> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
+ }
+ if (collapseOp.getReassociationIndices().empty()) {
+ auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+ int64_t srcRank =
+ cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(), zeroAffineMap,
+ ArrayRef<OpFoldResult>{});
+ for (int64_t i = 0; i < srcRank; i++) {
+ sourceIndices.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ extractOp, collapseOp.getSrc(), sourceIndices);
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
return success();
}
+/// If we have an ExtractOp consuming an InsertOp with the same
+/// indices, we can return the InsertOp's scalar directly.
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsert(ExtractOp extractOp) {
+ auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
+
+ auto isSame = [](Value a, Value b) {
+ return getAsOpFoldResult(a) == getAsOpFoldResult(b);
+ };
+ if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
+ llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
+ return insertOp.getScalar();
+
+ return {};
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (Attribute tensor = adaptor.getTensor()) {
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return elementsAttr.getValues<Attribute>()[indices];
}
+ if (Value result = foldExtractAfterInsert(*this))
+ return result;
+
return {};
}
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractFromTensorCast>(context);
}
+void mlir::tensor::populateFoldCollapseExtractPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ExtractFromCollapseShape>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// %c0 = arith.constant 0 : index
+/// %c4_f32 = arith.constant 4.0 : f32
+/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ // Requires a ranked tensor type.
+ auto destType =
+ llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+ if (!destType)
+ return failure();
+
+ // Pattern requires constant indices
+ SmallVector<uint64_t, 8> indices;
+ for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+ auto indiceAttr = dyn_cast<Attribute>(indice);
+ if (!indiceAttr)
+ return failure();
+ indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+ }
+
+ // Requires a constant scalar to insert
+ OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+ Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+ if (!scalarAttr)
+ return failure();
+
+ if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+ insertOp.getDest().getDefiningOp())) {
+ if (auto sourceAttr =
+ llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+ // Update the attribute at the inserted index.
+ auto sourceValues = sourceAttr.getValues<Attribute>();
+ auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+ std::vector<Attribute> updatedValues;
+ updatedValues.reserve(sourceAttr.getNumElements());
+ for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
+ updatedValues.push_back(i == flattenedIndex ? scalarAttr
+ : sourceValues[i]);
+ }
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ insertOp, sourceAttr.getType(),
+ DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
+ updatedValues));
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+
+} // namespace
+
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<InsertOpConstantFold>(context);
+}
+
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index cdcd7f305d2d9..0abec7e01d184 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
// -----
// CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
- // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
- return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+ // Fold an extract after an insert.
+ // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
+ %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
+ %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
+
+ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
+ return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
}
// -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
+
+// -----
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+ // Fold an insert into a splat.
+ // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+ // CHECK-LITERAL:
+ // CHECK-NEXT: return %[[C4]]
+ %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4_i32 = arith.constant 4 : i32
+ %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+ return %inserted : tensor<2x2xi32>
+}
+
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
new file mode 100644
index 0000000000000..c301f494a7c87
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
+
+// CHECK-LABEL: @extract_from_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
+func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
+ %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
+ %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
+ func.return %extracted, %extracted_0 : i8, i8
+}
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
+// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
+
+// -----
+
+// CHECK-LABEL: @extract_from_static_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
+ %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index e435130c2a417..0e191c32f009e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};
+ Option<bool> testFoldExtractFromCollapseShape{
+ *this, "test-fold-extract-from-collapse-shape",
+ llvm::cl::desc("Test folding of extract from collapse_shape"),
+ llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
+static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateFoldCollapseExtractPatterns(patterns);
+ (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
+ if (testFoldExtractFromCollapseShape)
+ applyFoldExtractFromCollapseShapePatterns(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
|
LGTM!
Will wait for tests to pass and then merge
I have some concerns about folders in the tensor dialect, in particular:
This kind of folder can quickly lead to a size explosion, because tensor constants can be pretty large.
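A hypothetical illustration of that concern (not from this PR; shapes and values are invented): once the destination is a large constant, each folded insert materializes another full-size dense attribute, and attributes are uniqued and retained by the MLIRContext.

```mlir
func.func @constant_bloat() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %f = arith.constant 1.0 : f32
  // A single splat constant: compact to store.
  %cst = arith.constant dense<0.0> : tensor<1024x1024xf32>
  %0 = tensor.insert %f into %cst[%c0, %c0] : tensor<1024x1024xf32>
  %1 = tensor.insert %f into %0[%c0, %c1] : tensor<1024x1024xf32>
  // After canonicalization, %0 and %1 would each become a distinct non-splat
  // dense constant of ~1M elements; that attribute storage lives in the
  // context even if an intermediate op is later removed.
  return %0, %1 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
}
```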
    return failure();

    // Pattern requires constant indices
    SmallVector<uint64_t, 8> indices;
SmallVector<uint64_t> indices;
Nit: don't use a specific size for SmallVector without a good reason to.
        // Update the attribute at the inserted index.
        auto sourceValues = sourceAttr.getValues<Attribute>();
        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
        std::vector<Attribute> updatedValues;
That is an unfortunately slow path for what will ultimately be converted to a vector of ints or floats.
Are you suggesting copying the source values instead of looping? Or switching on the int/float type and then converting back to attributes?
I'm saying we should be able to avoid using individual attributes per element in the common case, indeed.
We may be lacking (possibly templated?) helpers to do this conveniently.
Yeah, in retrospect it makes sense to type-switch on the attribute; the options are DenseIntElementsAttr and DenseFPElementsAttr.
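A minimal sketch of that direction (purely illustrative, not part of this PR; the helper name `updateElementAt` and its exact signature are assumptions): type-switch on the dense attribute kind and rebuild it from APInt/APFloat storage instead of a std::vector<Attribute>.

```cpp
#include <cstdint>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: return `src` with the element at `flatIndex` replaced
// by `scalar`, avoiding a per-element Attribute vector for the common
// integer/float element types.
static mlir::DenseElementsAttr updateElementAt(mlir::DenseElementsAttr src,
                                               uint64_t flatIndex,
                                               mlir::Attribute scalar) {
  auto type = llvm::cast<mlir::ShapedType>(src.getType());
  if (auto ints = llvm::dyn_cast<mlir::DenseIntElementsAttr>(src)) {
    auto range = ints.getValues<llvm::APInt>();
    llvm::SmallVector<llvm::APInt> values(range.begin(), range.end());
    values[flatIndex] = llvm::cast<mlir::IntegerAttr>(scalar).getValue();
    return mlir::DenseElementsAttr::get(type, values);
  }
  if (auto floats = llvm::dyn_cast<mlir::DenseFPElementsAttr>(src)) {
    auto range = floats.getValues<llvm::APFloat>();
    llvm::SmallVector<llvm::APFloat> values(range.begin(), range.end());
    values[flatIndex] = llvm::cast<mlir::FloatAttr>(scalar).getValue();
    return mlir::DenseElementsAttr::get(type, values);
  }
  // Other element kinds (complex, strings, ...) would keep the generic
  // Attribute-based path.
  return {};
}
```

The InsertOpConstantFold pattern could then call such a helper and fall back to the existing Attribute-based loop whenever it returns a null attribute.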
…f canonicalization (#142671): follow-ups from llvm/llvm-project#142458, in particular the concern that indiscriminately folding tensor constants can bloat the IR, as these constants can be arbitrarily large. Signed-off-by: Asra Ali <asraa@google.com>