Skip to content

[mlir][memref][NFC] Simplify constifyIndexValues #135940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2025

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Apr 16, 2025

Simplify the code by removing function pointers.

Depends on #135939.

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

Changes

Depends on #135939.


Full diff: https://github.com/llvm/llvm-project/pull/135940.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+50-103)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 63f5251398716..92f44c97ee5d7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
 // Utility functions for propagating static information
 //===----------------------------------------------------------------------===//
 
-/// Helper function that infers the constant values from a list of \p values,
-/// a \p memRefTy, and another helper function \p getAttributes.
-/// The inferred constant values replace the related `OpFoldResult` in
-/// \p values.
+/// Helper function that sets values[i] to constValues[i] if the latter is a
+/// static value, as indicated by ShapedType::kDynamic.
 ///
-/// \note This function shouldn't be used directly, instead, use the
-/// `getConstifiedMixedXXX` methods from the related operations.
-///
-/// \p getAttributes retuns a list of potentially constant values, as determined
-/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
-/// many elements as \p values or be empty.
-///
-/// E.g., consider the following example:
-/// ```
-/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
-///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-/// ```
-/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
-/// Now using this helper function with:
-/// - `values == [2, %dyn_stride]`,
-/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
-/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
-/// `getStridesAndOffset`), and
-/// - `isDynamic == ShapedType::isDynamic`
-/// Will yield: `values == [2, 1]`
-static void constifyIndexValues(
-    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
-    MLIRContext *ctxt,
-    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
-    llvm::function_ref<bool(int64_t)> isDynamic) {
-  SmallVector<int64_t> constValues = getAttributes(memRefTy);
-  Builder builder(ctxt);
-  for (const auto &it : llvm::enumerate(constValues)) {
-    int64_t constValue = it.value();
-    if (!isDynamic(constValue))
-      values[it.index()] = builder.getIndexAttr(constValue);
-  }
-  for (OpFoldResult &ofr : values) {
-    if (auto attr = dyn_cast<Attribute>(ofr)) {
-      // FIXME: We shouldn't need to do that, but right now, the static indices
-      // are created with the wrong type: `i64` instead of `index`.
-      // As a result, if we were to keep the attribute as is, we may fail to see
-      // that two attributes are equal because one would have the i64 type and
-      // the other the index type.
-      // The alternative would be to create constant indices with getI64Attr in
-      // this and the previous loop, but it doesn't logically make sense (we are
-      // dealing with indices here) and would only strenghten the inconsistency
-      // around how static indices are created (some places use getI64Attr,
-      // others use getIndexAttr).
-      // The workaround here is to stick to the IndexAttr type for all the
-      // values, hence we recreate the attribute even when it is already static
-      // to make sure the type is consistent.
-      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
+/// If constValues[i] is dynamic, tries to extract a constant value from
+/// value[i] to allow for additional folding opportunities. Also convertes all
+/// existing attributes to index attributes. (They may be i64 attributes.)
+static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
+                                ArrayRef<int64_t> constValues) {
+  assert(constValues.size() == values.size() &&
+         "incorrect number of const values");
+  for (int64_t i = 0, e = constValues.size(); i < e; ++i) {
+    Builder builder(values[i].getContext());
+    if (!ShapedType::isDynamic(constValues[i])) {
+      // Constant value is known, use it directly.
+      values[i] = builder.getIndexAttr(constValues[i]);
       continue;
     }
-    std::optional<int64_t> maybeConstant =
-        getConstantIntValue(cast<Value>(ofr));
-    if (maybeConstant)
-      ofr = builder.getIndexAttr(*maybeConstant);
+    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
+      // Try to extract a constant or convert an existing to index.
+      values[i] = builder.getIndexAttr(*cst);
+    }
   }
 }
 
-/// Wrapper around `getShape` that conforms to the function signature
-/// expected for `getAttributes` in `constifyIndexValues`.
-static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
-  ArrayRef<int64_t> sizes = memRefTy.getShape();
-  return SmallVector<int64_t>(sizes);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the offset and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult hasStaticInformation =
-      memrefType.getStridesAndOffset(strides, offset);
-  if (failed(hasStaticInformation))
-    return SmallVector<int64_t>();
-  return SmallVector<int64_t>(1, offset);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the strides and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult hasStaticInformation =
-      memrefType.getStridesAndOffset(strides, offset);
-  if (failed(hasStaticInformation))
-    return SmallVector<int64_t>();
-  return strides;
-}
-
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -1124,7 +1053,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    } // Check condition 2
+    }   // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
 
 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantSizes, ShapedType::isDynamic);
+  constifyIndexValues(values, getSource().getType().getShape());
   return values;
 }
 
 SmallVector<OpFoldResult>
 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
   SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantStrides, ShapedType::isDynamic);
+  SmallVector<int64_t> staticValues;
+  int64_t unused;
+  LogicalResult status =
+      getSource().getType().getStridesAndOffset(staticValues, unused);
+  (void)status;
+  assert(succeeded(status) && "could not get strides from type");
+  constifyIndexValues(values, staticValues);
   return values;
 }
 
 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
   OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
   SmallVector<OpFoldResult> values(1, offsetOfr);
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantOffset, ShapedType::isDynamic);
+  SmallVector<int64_t> staticValues, unused;
+  int64_t offset;
+  LogicalResult status =
+      getSource().getType().getStridesAndOffset(unused, offset);
+  (void)status;
+  assert(succeeded(status) && "could not get offset from type");
+  staticValues.push_back(offset);
+  constifyIndexValues(values, staticValues);
   return values[0];
 }
 
@@ -1975,15 +1914,18 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
 
 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getMixedSizes();
-  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
-                      ShapedType::isDynamic);
+  constifyIndexValues(values, getType().getShape());
   return values;
 }
 
 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
   SmallVector<OpFoldResult> values = getMixedStrides();
-  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
-                      ShapedType::isDynamic);
+  SmallVector<int64_t> staticValues;
+  int64_t unused;
+  LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
+  (void)status;
+  assert(succeeded(status) && "could not get strides from type");
+  constifyIndexValues(values, staticValues);
   return values;
 }
 
@@ -1991,8 +1933,13 @@ OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
   SmallVector<OpFoldResult> values = getMixedOffsets();
   assert(values.size() == 1 &&
          "reinterpret_cast must have one and only one offset");
-  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
-                      ShapedType::isDynamic);
+  SmallVector<int64_t> staticValues, unused;
+  int64_t offset;
+  LogicalResult status = getType().getStridesAndOffset(unused, offset);
+  (void)status;
+  assert(succeeded(status) && "could not get offset from type");
+  staticValues.push_back(offset);
+  constifyIndexValues(values, staticValues);
   return values[0];
 }
 
@@ -2062,7 +2009,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // Second, check the sizes.
       if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                        op.getConstifiedMixedSizes()))
-          return false;
+        return false;
 
       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&

Copy link

github-actions bot commented Apr 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_constify branch 3 times, most recently from d137ec0 to 05d2c7b Compare April 16, 2025 09:11
Base automatically changed from users/matthias-springer/reinterpret_cast_strided to main April 17, 2025 06:48
@matthias-springer matthias-springer force-pushed the users/matthias-springer/simplify_constify branch from 05d2c7b to b2beb0c Compare April 17, 2025 06:51
@matthias-springer matthias-springer merged commit 62d32c2 into main Apr 17, 2025
7 of 11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/simplify_constify branch April 17, 2025 06:58
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
Simplify the code by removing function pointers.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants