-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][memref][NFC] Simplify constifyIndexValues
#135940
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
matthias-springer
merged 2 commits into
main
from
users/matthias-springer/simplify_constify
Apr 17, 2025
Merged
[mlir][memref][NFC] Simplify constifyIndexValues
#135940
matthias-springer
merged 2 commits into
main
from
users/matthias-springer/simplify_constify
Apr 17, 2025
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesDepends on #135939. Full diff: https://github.com/llvm/llvm-project/pull/135940.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 63f5251398716..92f44c97ee5d7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//
-/// Helper function that infers the constant values from a list of \p values,
-/// a \p memRefTy, and another helper function \p getAttributes.
-/// The inferred constant values replace the related `OpFoldResult` in
-/// \p values.
+/// Helper function that sets values[i] to constValues[i] if the latter is a
+/// static value, as indicated by ShapedType::kDynamic.
///
-/// \note This function shouldn't be used directly, instead, use the
-/// `getConstifiedMixedXXX` methods from the related operations.
-///
-/// \p getAttributes retuns a list of potentially constant values, as determined
-/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
-/// many elements as \p values or be empty.
-///
-/// E.g., consider the following example:
-/// ```
-/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
-/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-/// ```
-/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
-/// Now using this helper function with:
-/// - `values == [2, %dyn_stride]`,
-/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
-/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
-/// `getStridesAndOffset`), and
-/// - `isDynamic == ShapedType::isDynamic`
-/// Will yield: `values == [2, 1]`
-static void constifyIndexValues(
- SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
- MLIRContext *ctxt,
- llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- SmallVector<int64_t> constValues = getAttributes(memRefTy);
- Builder builder(ctxt);
- for (const auto &it : llvm::enumerate(constValues)) {
- int64_t constValue = it.value();
- if (!isDynamic(constValue))
- values[it.index()] = builder.getIndexAttr(constValue);
- }
- for (OpFoldResult &ofr : values) {
- if (auto attr = dyn_cast<Attribute>(ofr)) {
- // FIXME: We shouldn't need to do that, but right now, the static indices
- // are created with the wrong type: `i64` instead of `index`.
- // As a result, if we were to keep the attribute as is, we may fail to see
- // that two attributes are equal because one would have the i64 type and
- // the other the index type.
- // The alternative would be to create constant indices with getI64Attr in
- // this and the previous loop, but it doesn't logically make sense (we are
- // dealing with indices here) and would only strenghten the inconsistency
- // around how static indices are created (some places use getI64Attr,
- // others use getIndexAttr).
- // The workaround here is to stick to the IndexAttr type for all the
- // values, hence we recreate the attribute even when it is already static
- // to make sure the type is consistent.
- ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
+/// If constValues[i] is dynamic, tries to extract a constant value from
+/// value[i] to allow for additional folding opportunities. Also convertes all
+/// existing attributes to index attributes. (They may be i64 attributes.)
+static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
+ ArrayRef<int64_t> constValues) {
+ assert(constValues.size() == values.size() &&
+ "incorrect number of const values");
+ for (int64_t i = 0, e = constValues.size(); i < e; ++i) {
+ Builder builder(values[i].getContext());
+ if (!ShapedType::isDynamic(constValues[i])) {
+ // Constant value is known, use it directly.
+ values[i] = builder.getIndexAttr(constValues[i]);
continue;
}
- std::optional<int64_t> maybeConstant =
- getConstantIntValue(cast<Value>(ofr));
- if (maybeConstant)
- ofr = builder.getIndexAttr(*maybeConstant);
+ if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
+ // Try to extract a constant or convert an existing to index.
+ values[i] = builder.getIndexAttr(*cst);
+ }
}
}
-/// Wrapper around `getShape` that conforms to the function signature
-/// expected for `getAttributes` in `constifyIndexValues`.
-static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
- ArrayRef<int64_t> sizes = memRefTy.getShape();
- return SmallVector<int64_t>(sizes);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the offset and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
- SmallVector<int64_t> strides;
- int64_t offset;
- LogicalResult hasStaticInformation =
- memrefType.getStridesAndOffset(strides, offset);
- if (failed(hasStaticInformation))
- return SmallVector<int64_t>();
- return SmallVector<int64_t>(1, offset);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the strides and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
- SmallVector<int64_t> strides;
- int64_t offset;
- LogicalResult hasStaticInformation =
- memrefType.getStridesAndOffset(strides, offset);
- if (failed(hasStaticInformation))
- return SmallVector<int64_t>();
- return strides;
-}
-
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -1124,7 +1053,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
- constifyIndexValues(values, getSource().getType(), getContext(),
- getConstantSizes, ShapedType::isDynamic);
+ constifyIndexValues(values, getSource().getType().getShape());
return values;
}
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
- constifyIndexValues(values, getSource().getType(), getContext(),
- getConstantStrides, ShapedType::isDynamic);
+ SmallVector<int64_t> staticValues;
+ int64_t unused;
+ LogicalResult status =
+ getSource().getType().getStridesAndOffset(staticValues, unused);
+ (void)status;
+ assert(succeeded(status) && "could not get strides from type");
+ constifyIndexValues(values, staticValues);
return values;
}
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
SmallVector<OpFoldResult> values(1, offsetOfr);
- constifyIndexValues(values, getSource().getType(), getContext(),
- getConstantOffset, ShapedType::isDynamic);
+ SmallVector<int64_t> staticValues, unused;
+ int64_t offset;
+ LogicalResult status =
+ getSource().getType().getStridesAndOffset(unused, offset);
+ (void)status;
+ assert(succeeded(status) && "could not get offset from type");
+ staticValues.push_back(offset);
+ constifyIndexValues(values, staticValues);
return values[0];
}
@@ -1975,15 +1914,18 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getMixedSizes();
- constifyIndexValues(values, getType(), getContext(), getConstantSizes,
- ShapedType::isDynamic);
+ constifyIndexValues(values, getType().getShape());
return values;
}
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
SmallVector<OpFoldResult> values = getMixedStrides();
- constifyIndexValues(values, getType(), getContext(), getConstantStrides,
- ShapedType::isDynamic);
+ SmallVector<int64_t> staticValues;
+ int64_t unused;
+ LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
+ (void)status;
+ assert(succeeded(status) && "could not get strides from type");
+ constifyIndexValues(values, staticValues);
return values;
}
@@ -1991,8 +1933,13 @@ OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
SmallVector<OpFoldResult> values = getMixedOffsets();
assert(values.size() == 1 &&
"reinterpret_cast must have one and only one offset");
- constifyIndexValues(values, getType(), getContext(), getConstantOffset,
- ShapedType::isDynamic);
+ SmallVector<int64_t> staticValues, unused;
+ int64_t offset;
+ LogicalResult status = getType().getStridesAndOffset(unused, offset);
+ (void)status;
+ assert(succeeded(status) && "could not get offset from type");
+ staticValues.push_back(offset);
+ constifyIndexValues(values, staticValues);
return values[0];
}
@@ -2062,7 +2009,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
// Second, check the sizes.
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
op.getConstifiedMixedSizes()))
- return false;
+ return false;
// Finally, check the offset.
assert(op.getMixedOffsets().size() == 1 &&
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
d137ec0
to
05d2c7b
Compare
qcolombet
reviewed
Apr 16, 2025
qcolombet
approved these changes
Apr 16, 2025
Base automatically changed from
users/matthias-springer/reinterpret_cast_strided
to
main
April 17, 2025 06:48
05d2c7b
to
b2beb0c
Compare
var-const
pushed a commit
to ldionne/llvm-project
that referenced
this pull request
Apr 17, 2025
Simplify the code by removing function pointers.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Simplify the code by removing function pointers.
Depends on #135939.