Skip to content

[MLIR] Determine contiguousness of memrefs with dynamic dimensions #142421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

@momchil-velikov momchil-velikov commented Jun 2, 2025

This patch enhances MemRefType::areTrailingDimsContiguous to also handle memrefs with dynamic dimensions.

The implementation itself is based on a new member function MemRefType::getMaxCollapsableTrailingDims that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontiguous or statically unknown strides.

(see also #142422)

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Momchil Velikov (momchil-velikov)

Changes

This patch enhances MemRefType::areTrailingDimsContiguous to also handle memrefs with dynamic dimensions.

The implementation itself is based on a new member function MemRefType::getMaxCollapsableTrailingDims that return the maximum number of trailing dimensions that can be collapsed - trivially all dimensions for memrefs with identity layout, or by examining the memref strides stopping at discontiguous or statically unknown strides.


Full diff: https://github.com/llvm/llvm-project/pull/142421.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+14)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+26-21)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+38-11)
  • (modified) mlir/unittests/Dialect/MemRef/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/MemRef/LayoutTest.cpp (+190)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
+    /// Return the maximum number of trailing dimensions that can be
+    /// collapsed.
+    ///
+    /// Examples:
+    ///   - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
+    ///     trailing dimensions is 0
+    ///   - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+    ///     trailing dimensions is 3
+    ///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
+    ///     collapsable trailing dimensions is 2.
+    ///   - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+    ///     is 4.
+    int64_t getMaxCollapsableTrailingDims();
+
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
     /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  return getLayout().isIdentity() ||
+         getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
 
-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+  const int64_t n = getRank();
 
+  // memrefs with identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;
 
+  // Get the strides (if any). Failing to do that, conservatively assume a
+  // non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;
 
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  auto shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
   }
 
-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
+  return n;
 }
 
 MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // One of the dims to be flattened is dynamic - not supported ATM.
 
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]
+// CHECK-SAME:    %[[IDX_2:arg1]]
+// CHECK-SAME:    %[[MEM:arg2]]
+// CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // One of the dims to be flattened is dynamic - not supported ATM.
 
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]: index
+// CHECK-SAME:    %[[IDX_2:arg1]]: index
+// CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LayoutTest.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>
+  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<2x2x2xf32, strided<[8,2,1]>
+  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x2x2xf32, strided<[8,4,1]>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x2xf32, strided<[8,4,2]>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x2x?xf32, strided<[?,?,1]>
+  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+  EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x?xf32, strided<[?,?,2]>
+  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+  EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x?x2xf32, strided<[?,2,1]>
+  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x?x2xf32, strided<[?,4,1]>
+  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+  EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x?x2xf32, strided<[?,4,2]>
+  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+  EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<?x2x2xf32, strided<[4,2,1]>
+  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<?x2x2xf32, strided<[8,2,1]>
+  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<?x2x2xf32, strided<[8,4,1]>
+  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<?x2x2xf32, strided<[8,4,2]>
+  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>
+  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,2,1]>
+  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,1]>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,2]>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,1]>
+  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+  EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,2]>
+  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,2,1]>
+  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,1]>
+  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+  EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,2]>
+  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[4,2,1]>
+  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,2,1]>
+  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,1]>
+  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,2]>
+  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+
+  // memref<2x2x2xf32>
+  auto m1 = MemRefType::get({2, 2, 2}, f32);
+  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+  // memref<?x?x?xf32>
+  auto m2 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks good! I like the overall design, especially the new c++ testing approach. My comments are all minor.

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from 1d025f3 to 1fe6866 Compare June 3, 2025 17:06
@momchil-velikov momchil-velikov requested a review from Mogball as a code owner June 3, 2025 17:06
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch 2 times, most recently from c97c6c1 to e3b6e21 Compare June 5, 2025 11:25
Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank for for adding the stride = 1 cases, looks good to me! I've added a few minor comments.

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from e3b6e21 to ab4681a Compare June 5, 2025 16:50
Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Maybe allow others with more experience here than me some time to review

This patch enhances `MemRefType::areTrailingDimsContiguous`
to also handle memrefs with dynamic dimensions.

The implementation itself is based on a new member function
`MemRefType::getMaxCollapsableTrailingDims` that return
the maximum number of trailing dimensions that can be
collapsed - trivially all dimensions for memrefs with identity
layout, or by examining the memref strides stopping at discontguous
or statically unknown strides.
`computeStrides` does not acccess the first element of `sizes`
 - rename `getMaxCollapsabelTrailingDims` to `getMaxContiguousTrailingDims`
 - new set of examples
 - remove redundant call to `isIdentify()`
 - make sure a variable type is visible on the declaration line
 - some micro-optimisation
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/memref-contig branch from ab4681a to 1cee446 Compare June 6, 2025 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants