-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Affine] Generalize the linearize(delinearize()) simplifications #117637
[mlir][Affine] Generalize the linearize(delinearize()) simplifications #117637
Conversation
@llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) ChangesThe existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delineraize_index was passed to an affine.lineraize_index and the basis elements matched exactly (except possibly for the outer bounds). This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work. This sort of simplification is reasably easy to observe at the level of splititng and merging indexes, but difficult to perform once the underlying arithmetic operations have been created. Therefore, this commit generalizes the existing simplification logic. Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbonded, or when That is, we can now simplify
to the simpler
This new pattern also works with dynamically-sized basis values. While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing Patch is 33.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117637.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..0d1e4ede795ce5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%indices_2 = affine.apply #map2()[%linear_index]
```
+ In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+ `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
- Note that, due to the constraints of affine maps, all the basis elements must
+ Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+ when one of the `OpFoldResult` builders is called but the first element of the
+ basis is `nullptr`, that first element is ignored and the builder proceeds as if
+ there was no outer bound.
+
+ Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
}];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per result of the operation. If
+ /// there is no outer bound specified, the leading entry of this result will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+ gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
+ As a convenience, and for symmetry with `getPaddedBasis()`, ifg the first
+ element of a set of `OpFoldResult`s passed to the builders of this operation is
+ `nullptr`, that element is ignored.
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per index operand of the operation.
+ /// If there is no outer bound specified, the leading entry of this basis will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..b7c5e8eff8a8cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
Value linearIndex,
ArrayRef<OpFoldResult> basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4614,6 +4622,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4632,25 +4647,27 @@ struct DropUnitExtentBasis
return zero.value();
};
- bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<OpFoldResult> newBasis;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ for (auto [index, basis] :
+ llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+ std::optional<int64_t> basisVal =
+ basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
- replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+ replacements[index] = getZero();
else
newBasis.push_back(basis);
}
- if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getNumResults())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newBasis.empty() || !hasOuterBound) {
+ if (!newBasis.empty()) {
+ // Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+ loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4831,6 +4848,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == Value())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4843,6 +4862,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == OpFoldResult())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4918,7 +4939,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
builder);
}
- return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
}
namespace {
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If conseceutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> seen;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %0:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned firstDelinArg = asResult.getResultNumber();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+ OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+ OpFoldResult firstLinBound = linBasis[linArgIdx];
+ bool boundsMatch = firstDelinBound == firstLinBound;
+ bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+ bool knownByDisjoint =
+ linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+ if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+ linArgIdx++;
+ continue;
+ }
- if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ unsigned j = 1;
+ unsigned numDelinOuts = delinearizeOp.getNumResults();
+ for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+ ++j) {
+ if (multiIndex[linArgIdx + j] !=
+ delinearizeOp.getResult(firstDelinArg + j))
+ break;
+ if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+ break;
+ }
+ // If there're multiple matches against the same delinearize_index,
+ // only rewrite the first one we find to prevent invalidations. The next
+ // ones will be taken caer of by subsequent pattern invocations.
+ if (j <= 1 || !seen.insert(delinearizeOp).second) {
+ linArgIdx++;
+ continue;
+ }
+ matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+ linArgIdx += j;
+ }
+
+ if (matches.empty())
return rewriter.notifyMatchFailure(
- linearizeOp, "not all indices come from delinearize");
+ linearizeOp, "no run of delinearize outputs to deal with");
+
+ SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+ SmallVector<Value> newIndex;
+ newIndex.reserve(numLinArgs);
+ SmallVector<OpFoldResult> newBasis;
+ newBasis.reserve(numLinArgs);
+ unsigned prevMatchEnd = 0;
+ for (Match m : matches) {
+ unsigned gap = m.linStart - prevMatchEnd;
+ llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+ llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+ // Update here so we don't forget this during early continues
+ prevMatchEnd = m.linStart + m.length;
+
+ // We use the slice from the linearize's basis above because of the
+ // "bounds inferred from `disjoint`" case above.
+ OpFoldResult newSize =
+ computeProduct(linearizeOp.getLoc(), rewriter,
+ linBasisRef.slice(m.linStart, m.length));
+
+ // Trivial case where we can just skip past the delinearize all together
+ if (m.length == m.delinearize.getNumResults()) {
+ newIndex.push_back(m.delinearize.getLinearIndex());
+ newBasis.push_back(newSize);
+ continue;
+ }
+ SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+ newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+ newDelinBasis.begin() + m.delinStart + m.length);
+ newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+ auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ newDelinBasis);
+
+ // Swap all the uses of the unaffected delinearize outputs to the new
+ // delinearization so that the old code can be removed if this
+ // linearize_index is the only user of the merged results.
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().take_front(m.delinStart),
+ newDelinearize.getResults().take_front(m.delinStart)));
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().drop_front(m.delinStart + m.length),
+ newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+ Value newLinArg = newDelinearize.getResult(m.delinStart);
+ newIndex.push_back(newLinArg);
+ newBasis.push_back(newSize);
+ }
+ llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+ llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+ rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+ linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
- rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+ for (auto [from, to] : delinearizeReplacements)
+ rewriter.replaceAllUsesWith(from, to);
return success();
}
};
@@ -5049,7 +5241,7 @@ struct DropLinearizeLeadingZero final
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..c153d32670d574 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesThe existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delineraize_index was passed to an affine.lineraize_index and the basis elements matched exactly (except possibly for the outer bounds). This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work. This sort of simplification is reasably easy to observe at the level of splititng and merging indexes, but difficult to perform once the underlying arithmetic operations have been created. Therefore, this commit generalizes the existing simplification logic. Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbonded, or when That is, we can now simplify
to the simpler
This new pattern also works with dynamically-sized basis values. While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing Patch is 33.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117637.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..0d1e4ede795ce5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%indices_2 = affine.apply #map2()[%linear_index]
```
+ In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+ `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
- Note that, due to the constraints of affine maps, all the basis elements must
+ Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+ when one of the `OpFoldResult` builders is called but the first element of the
+ basis is `nullptr`, that first element is ignored and the builder proceeds as if
+ there was no outer bound.
+
+ Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
}];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per result of the operation. If
+ /// there is no outer bound specified, the leading entry of this result will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+ gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
+ As a convenience, and for symmetry with `getPaddedBasis()`, ifg the first
+ element of a set of `OpFoldResult`s passed to the builders of this operation is
+ `nullptr`, that element is ignored.
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per index operand of the operation.
+ /// If there is no outer bound specified, the leading entry of this basis will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..b7c5e8eff8a8cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
Value linearIndex,
ArrayRef<OpFoldResult> basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4614,6 +4622,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4632,25 +4647,27 @@ struct DropUnitExtentBasis
return zero.value();
};
- bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<OpFoldResult> newBasis;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ for (auto [index, basis] :
+ llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+ std::optional<int64_t> basisVal =
+ basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
- replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+ replacements[index] = getZero();
else
newBasis.push_back(basis);
}
- if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getNumResults())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newBasis.empty() || !hasOuterBound) {
+ if (!newBasis.empty()) {
+ // Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+ loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4831,6 +4848,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == Value())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4843,6 +4862,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == OpFoldResult())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4918,7 +4939,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
builder);
}
- return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
}
namespace {
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If conseceutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> seen;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %0:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned firstDelinArg = asResult.getResultNumber();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+ OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+ OpFoldResult firstLinBound = linBasis[linArgIdx];
+ bool boundsMatch = firstDelinBound == firstLinBound;
+ bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+ bool knownByDisjoint =
+ linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+ if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+ linArgIdx++;
+ continue;
+ }
- if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ unsigned j = 1;
+ unsigned numDelinOuts = delinearizeOp.getNumResults();
+ for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+ ++j) {
+ if (multiIndex[linArgIdx + j] !=
+ delinearizeOp.getResult(firstDelinArg + j))
+ break;
+ if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+ break;
+ }
+ // If there're multiple matches against the same delinearize_index,
+ // only rewrite the first one we find to prevent invalidations. The next
+ // ones will be taken caer of by subsequent pattern invocations.
+ if (j <= 1 || !seen.insert(delinearizeOp).second) {
+ linArgIdx++;
+ continue;
+ }
+ matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+ linArgIdx += j;
+ }
+
+ if (matches.empty())
return rewriter.notifyMatchFailure(
- linearizeOp, "not all indices come from delinearize");
+ linearizeOp, "no run of delinearize outputs to deal with");
+
+ SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+ SmallVector<Value> newIndex;
+ newIndex.reserve(numLinArgs);
+ SmallVector<OpFoldResult> newBasis;
+ newBasis.reserve(numLinArgs);
+ unsigned prevMatchEnd = 0;
+ for (Match m : matches) {
+ unsigned gap = m.linStart - prevMatchEnd;
+ llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+ llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+ // Update here so we don't forget this during early continues
+ prevMatchEnd = m.linStart + m.length;
+
+ // We use the slice from the linearize's basis above because of the
+ // "bounds inferred from `disjoint`" case above.
+ OpFoldResult newSize =
+ computeProduct(linearizeOp.getLoc(), rewriter,
+ linBasisRef.slice(m.linStart, m.length));
+
+ // Trivial case where we can just skip past the delinearize all together
+ if (m.length == m.delinearize.getNumResults()) {
+ newIndex.push_back(m.delinearize.getLinearIndex());
+ newBasis.push_back(newSize);
+ continue;
+ }
+ SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+ newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+ newDelinBasis.begin() + m.delinStart + m.length);
+ newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+ auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ newDelinBasis);
+
+ // Swap all the uses of the unaffected delinearize outputs to the new
+ // delinearization so that the old code can be removed if this
+ // linearize_index is the only user of the merged results.
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().take_front(m.delinStart),
+ newDelinearize.getResults().take_front(m.delinStart)));
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().drop_front(m.delinStart + m.length),
+ newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+ Value newLinArg = newDelinearize.getResult(m.delinStart);
+ newIndex.push_back(newLinArg);
+ newBasis.push_back(newSize);
+ }
+ llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+ llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+ rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+ linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
- rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+ for (auto [from, to] : delinearizeReplacements)
+ rewriter.replaceAllUsesWith(from, to);
return success();
}
};
@@ -5049,7 +5241,7 @@ struct DropLinearizeLeadingZero final
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..c153d32670d574 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this @krzysz00 !
/// Return the vector with one basis element per index operand of the operation. | ||
/// If there is no outer bound specified, the leading entry of this basis will be | ||
/// nullptr. | ||
SmallVector<OpFoldResult> getPaddedBasis(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would getPaddedMixedBasis()
make more sense for the name here ?
And accordingly we can make the doc comment simpler : "Same as getMixedBasis, but the leading entry will include a nullptr if no outer bound is specified
" ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might, and ... yeah, I'm not strongly attached to the name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But also it's a bit wordier - let me know if you have strong thoughts here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, no worries - but I'd still prefer making the doc comment simpler : Same as getMixedBasis, but the leading entry will include a nullptr if no outer bound is specified
.
7f014fb
to
6a8d42d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It'd be nice to make the current pattern work for the following case :-
func.func private @myfunc(%arg0: index) -> (index, index, index) {
%c0 = arith.constant 0 : index
%0:3 = affine.delinearize_index %arg0 into (8, 8, 4) : index, index, index
%1 = affine.linearize_index disjoint [%0#2, %0#0, %c0, %c0] by (4, 8, 4, 8) : index
%2 = affine.linearize_index disjoint [%0#1, %0#2, %c0, %c0] by (8, 4, 8, 4) : index
%3 = affine.linearize_index disjoint [%0#1, %0#0, %c0, %c0] by (8, 8, 4, 4) : index
return %1, %2, %3 : index, index, index
}
Basically this is the case where multiple affine.linearize_index
feed off of SAME affine.delinearize_index
. Currently this causes operand <NUM> does not dominate this use
resulting in an invalid IR.
Perhaps some missing setInsertionPoint*
in the pattern.
The existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delineraize_index was passed to an affine.lineraize_index and the basis elements matched exactly (except possibly for the outer bounds). This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work. This sort of simplification is reasably easy to observe at the level of splititng and merging indexes, but difficult to perform once the underlying arithmetic operations have been created. Therefore, this commit generalizes the existing simplification logic. Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbonded, or when `linearize_index disjoint` implies a bound not present on the `delinearize_index`) will be reduced to one signle delinearize_index output, whose basis element (that is, size or length) is equal to the product of the sizes that were simplified away. That is, we can now simplify %0:2 = affine.delinearize_index %n into (8, 8) : inde, index %1 = affine.linearize_index [%x, %0#0, %0#1, %y] by (3, 8, 8, 5) : index to the simpler %1 = affine.linearize_index [%x, %n, %y] by (3, 64, 5) : index This new pattern also works with dynamically-sized basis values. While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing potentially-underspecified basis elements simpler in some cases.
…handling of the residual
6a8d42d
to
8c53d39
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @krzysz00 for the changes!
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/55/builds/5033 Here is the relevant piece of the build log for the reference
|
The existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delineraize_index was passed to an affine.lineraize_index and the basis elements matched exactly (except possibly for the outer bounds).
This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work.
This sort of simplification is reasably easy to observe at the level of splititng and merging indexes, but difficult to perform once the underlying arithmetic operations have been created.
Therefore, this commit generalizes the existing simplification logic.
Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbonded, or when
linearize_index disjoint
implies a bound not present on thedelinearize_index
) will be reduced to one signle delinearize_index output, whose basis element (that is, size or length) is equal to the product of the sizes that were simplified away.That is, we can now simplify
to the simpler
This new pattern also works with dynamically-sized basis values.
While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing
potentially-underspecified basis elements simpler in some cases.