[MLIR] Determine contiguousness of memrefs with dynamic dimensions #142421
base: main
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-ods

Author: Momchil Velikov (momchil-velikov)

Changes

This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that returns the maximum number of trailing dimensions that can be collapsed.

Full diff: https://github.com/llvm/llvm-project/pull/142421.diff

5 Files Affected:
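In effect, the following now holds (a minimal sketch mirroring case m7 of the LayoutTest.cpp unit test in the diff below; the wrapper function name is illustrative only):

#include <cassert>
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

void contiguityExample() {
  MLIRContext ctx;
  OpBuilder b(&ctx);
  // memref<2x?x2xf32, strided<[?, 2, 1]>>: one size and one stride are dynamic,
  // yet the two innermost dimensions are still provably contiguous.
  auto layout = StridedLayoutAttr::get(&ctx, /*offset=*/0,
                                       {ShapedType::kDynamic, 2, 1});
  auto m = MemRefType::get({2, ShapedType::kDynamic, 2}, b.getF32Type(), layout);
  assert(m.getMaxCollapsableTrailingDims() == 2);
  assert(m.areTrailingDimsContiguous(2));  // previously false for dynamic shapes
  assert(!m.areTrailingDimsContiguous(3)); // outermost stride is unknown
}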
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
+ /// Return the maximum number of trailing dimensions that can be
+ /// collapsed.
+ ///
+ /// Examples:
+ /// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
+ /// trailing dimensions is 0
+ /// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+ /// trailing dimensions is 3
+ /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
+ /// collapsable trailing dimensions is 2.
+ /// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+ /// is 4.
+ int64_t getMaxCollapsableTrailingDims();
+
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- if (!isLastDimUnitStride())
- return false;
+ return getLayout().isIdentity() ||
+ getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
- auto memrefShape = getShape().take_back(n);
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+ const int64_t n = getRank();
+ // memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
- return true;
+ return n;
+ // Get the strides (if any). Failing to do that, conservatively assume a
+ // non-contiguous layout.
int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
- if (strides.empty())
- return true;
+ SmallVector<int64_t> strides;
+ if (!succeeded(getStridesAndOffset(strides, offset)))
+ return 0;
- // Check whether strides match "flattened" dims.
- SmallVector<int64_t> flattenedDims;
- auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
- dimProduct *= dim;
- flattenedDims.push_back(dimProduct);
+ auto shape = getShape();
+
+ // A memref with dimensions `d0, d1, ..., dn-1` and strides
+ // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+ // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+ // for `i` in `[k, n-1]`.
+ int64_t dimProduct = 1;
+ for (int64_t i = n - 1; i >= 0; --i) {
+ if (strides[i] != dimProduct)
+ return n - i - 1;
+ if (shape[i] == ShapedType::kDynamic)
+ return n - i;
+ dimProduct *= shape[i];
}
- strides = strides.drop_back(1);
- return llvm::equal(strides, llvm::reverse(flattenedDims));
+ return n;
}
MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
return
}
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRMemRefTests
InferShapeTest.cpp
+ LayoutTest.cpp
)
mlir_target_link_libraries(MLIRMemRefTests
PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+
+ // memref<2x2x2xf32>
+ auto m1 = MemRefType::get({2, 2, 2}, f32);
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<?x?x?xf32>
+ auto m2 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}
Thanks, this looks good! I like the overall design, especially the new C++ testing approach. My comments are all minor.
1d025f3 to 1fe6866 (force-push)
c97c6c1 to e3b6e21 (force-push)
Thanks for adding the stride = 1 cases, looks good to me! I've added a few minor comments.
e3b6e21 to ab4681a (force-push)
LGTM! Maybe allow others with more experience here than me some time to review.
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that returns the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or otherwise determined by examining the memref strides, stopping at a discontiguous or statically unknown stride.
`computeStrides` does not access the first element of `sizes`
- rename `getMaxCollapsableTrailingDims` to `getMaxContiguousTrailingDims`
- new set of examples
- remove redundant call to `isIdentity()`
- make sure a variable type is visible on the declaration line
- some micro-optimisation
ab4681a to 1cee446 (force-push)
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions. The implementation itself is based on a new member function `MemRefType::getMaxCollapsableTrailingDims` that returns the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or otherwise determined by examining the memref strides, stopping at a discontiguous or statically unknown stride.

(see also #142422)
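For readers skimming the conversation, here is a small self-contained sketch of the stride walk described above, in plain C++. The function `maxContiguousTrailingDims` and the constant `kDynamic` are illustrative stand-ins for the member function and `ShapedType::kDynamic`, not the MLIR sources themselves; the two checks mirror the m2 and m7 cases from the new LayoutTest.cpp.

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Illustrative stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Sketch of the stride walk: scan dimensions from innermost to outermost,
// stopping at the first stride that does not match the running product of the
// inner sizes, or just after a dynamically sized dimension (its own stride is
// verified, but the next outer stride can no longer be proven).
int64_t maxContiguousTrailingDims(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &strides) {
  const int64_t n = static_cast<int64_t>(shape.size());
  int64_t product = 1;
  for (int64_t i = n - 1; i >= 0; --i) {
    if (strides[i] != product)
      return n - i - 1;
    if (shape[i] == kDynamic)
      return n - i;
    product *= shape[i];
  }
  return n;
}

int main() {
  // memref<2x2x2xf32, strided<[8, 2, 1]>>: only the two innermost dims are contiguous.
  assert(maxContiguousTrailingDims({2, 2, 2}, {8, 2, 1}) == 2);
  // memref<2x?x2xf32, strided<[?, 2, 1]>>: contiguity is provable up to and
  // including the dynamic dimension.
  assert(maxContiguousTrailingDims({2, kDynamic, 2}, {kDynamic, 2, 1}) == 2);
  return 0;
}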