[mlir][Affine] Generalize the linearize(delinearize()) simplifications #117637

Merged


krzysz00
Contributor

The existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delinearize_index was passed to an affine.linearize_index and the basis elements matched exactly (except possibly for the outer bounds).

This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work.
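To make the "divisions and modulos undone by additions and multiplications" point concrete, here is a minimal sketch in plain C++ (not MLIR code). The function names `delinearize` and `linearize` are hypothetical stand-ins for the arithmetic the two ops describe, and the sketch assumes a non-negative input, strictly positive basis elements, and a basis given without an outer bound.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of affine.delinearize_index without an outer bound:
// splits a non-negative `x` into basis.size() + 1 indices using a chain of
// modulos and divisions, innermost (last) index first.
std::vector<int64_t> delinearize(int64_t x, const std::vector<int64_t> &basis) {
  std::vector<int64_t> out(basis.size() + 1);
  for (std::size_t i = basis.size(); i > 0; --i) {
    out[i] = x % basis[i - 1];
    x /= basis[i - 1];
  }
  out[0] = x; // The outermost index is whatever remains.
  return out;
}

// Hypothetical model of affine.linearize_index without an outer bound:
// recombines the indices with multiplications and additions (Horner form).
int64_t linearize(const std::vector<int64_t> &idx,
                  const std::vector<int64_t> &basis) {
  int64_t acc = idx[0];
  for (std::size_t i = 1; i < idx.size(); ++i)
    acc = acc * basis[i - 1] + idx[i];
  return acc;
}
```

For instance, `delinearize(100, {8, 8})` yields `{1, 4, 4}`, and `linearize({1, 4, 4}, {8, 8})` gives back `100`, so the pair of calls is pure overhead when the bases agree.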

This sort of simplification is reasonably easy to observe at the level of splitting and merging indexes, but difficult to perform once the underlying arithmetic operations have been created.

Therefore, this commit generalizes the existing simplification logic.

Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbounded, or when linearize_index disjoint implies a bound not present on the delinearize_index) will be reduced to one single delinearize_index output, whose basis element (that is, size or length) is equal to the product of the sizes that were simplified away.

That is, we can now simplify

%0:2 = affine.delinearize_index %n into (8, 8) : index, index
%1 = affine.linearize_index [%x, %0#0, %0#1, %y] by (3, 8, 8, 5) : index

to the simpler

%1 = affine.linearize_index [%x, %n, %y] by (3, 64, 5) : index
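The rewrite above can be sketched on a toy model in plain C++. This is not the MLIR pattern itself: `Operand`, `Linearize`, and `mergeFullRun` are hypothetical names, the basis is assumed to be fully static, and only the case where the run covers every result of the delinearize with exactly matching bounds is handled. The actual pattern also covers partial runs, dynamic basis values, and the outermost and `disjoint` cases.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for one linearize_index operand.
struct Operand {
  std::string name; // e.g. "%x" or "%0#1"
  int delinResult;  // result number if it comes from the delinearize, else -1
};

// Hypothetical stand-in for a linearize_index with one basis entry per operand.
struct Linearize {
  std::vector<Operand> operands;
  std::vector<int64_t> basis;
};

// If results 0..n-1 of a delinearize of `source` (with basis `delinBasis`)
// appear consecutively in `lin` under the same basis, replace the run with
// `source` and the product of the merged sizes. Otherwise return `lin` as is.
Linearize mergeFullRun(const Linearize &lin,
                       const std::vector<int64_t> &delinBasis,
                       const std::string &source) {
  const std::size_t n = delinBasis.size();
  if (n < 2) // Only runs of two or more results are merged.
    return lin;
  for (std::size_t start = 0; start + n <= lin.operands.size(); ++start) {
    bool match = true;
    for (std::size_t j = 0; j < n && match; ++j)
      match = lin.operands[start + j].delinResult == static_cast<int>(j) &&
              lin.basis[start + j] == delinBasis[j];
    if (!match)
      continue;
    int64_t product = 1;
    for (int64_t b : delinBasis)
      product *= b;
    Linearize out;
    for (std::size_t i = 0; i < lin.operands.size(); ++i) {
      if (i == start) { // Emit the merged operand in place of the run.
        out.operands.push_back({source, -1});
        out.basis.push_back(product);
      }
      if (i >= start && i < start + n)
        continue; // Skip the operands that were merged away.
      out.operands.push_back(lin.operands[i]);
      out.basis.push_back(lin.basis[i]);
    }
    return out;
  }
  return lin;
}
```

Applied to the operands `[%x, %0#0, %0#1, %y]` with basis `(3, 8, 8, 5)` and a delinearize of `%n` into `(8, 8)`, this model produces `[%x, %n, %y]` with basis `(3, 64, 5)`, matching the example.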

This new pattern also works with dynamically-sized basis values.

While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing
potentially-underspecified basis elements simpler in some cases.
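As a rough illustration of what a padded basis buys, here is a small sketch in plain C++ (C++17). `Basis` and `padBasis` are hypothetical names, and `std::optional` stands in for the null `OpFoldResult` that the real `getPaddedBasis()` uses as its placeholder; only static sizes are modeled.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One entry per result/operand; std::nullopt marks a missing outer bound.
using Basis = std::vector<std::optional<int64_t>>;

// Hypothetical model of getPaddedBasis(): if the basis has one element fewer
// than the number of results, prepend an empty entry so that padded[i] always
// lines up with result (or operand) i.
Basis padBasis(const std::vector<int64_t> &basis, std::size_t numResults) {
  Basis padded;
  if (basis.size() + 1 == numResults)
    padded.push_back(std::nullopt);
  for (int64_t b : basis)
    padded.emplace_back(b);
  return padded;
}
```

With this shape, code that walks results and basis elements together can use the same index for both, instead of adjusting by one depending on whether an outer bound is present.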

@llvmbot
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

Patch is 33.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117637.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+26-1)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+221-29)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+212-12)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..0d1e4ede795ce5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     %indices_2 = affine.apply #map2()[%linear_index]
     ```
 
+    In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+    `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
     The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
     If there are N basis elements, the first one will not be used during computations,
     but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
     ```
 
-    Note that, due to the constraints of affine maps, all the basis elements must
+    Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+    when one of the `OpFoldResult` builders is called but the first element of the
+    basis is `nullptr`, that first element is ignored and the builder proceeds as if
+    there was no outer bound.
+
+    Due to the constraints of affine maps, all the basis elements must
     be strictly positive. A dynamic basis element being 0 or negative causes
     undefined behavior.
   }];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     /// Return a vector that contains the basis of the operation, removing
     /// the outer bound if one is present.
     SmallVector<OpFoldResult> getEffectiveBasis();
+
+    /// Return the vector with one basis element per result of the operation. If
+    /// there is no outer bound specified, the leading entry of this result will be
+    /// nullptr.
+    SmallVector<OpFoldResult> getPaddedBasis();
   }];
 
   let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
     ```
 
+    In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+    gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
     The basis may either have `N` or `N-1` elements, where `N` is the number of
     inputs to linearize_index. If `N` inputs are provided, the first one is not used
     in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     If all `N` basis elements are provided, the linearize_index operation is said to
     "have an outer bound".
 
+    As a convenience, and for symmetry with `getPaddedBasis()`, if the first
+    element of a set of `OpFoldResult`s passed to the builders of this operation is
+    `nullptr`, that element is ignored.
+
     If the `disjoint` property is present, this is an optimization hint that,
     for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
     except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     /// Return a vector that contains the basis of the operation, removing
     /// the outer bound if one is present.
     SmallVector<OpFoldResult> getEffectiveBasis();
+
+    /// Return the vector with one basis element per index operand of the operation.
+    /// If there is no outer bound specified, the leading entry of this basis will be
+    /// nullptr.
+    SmallVector<OpFoldResult> getPaddedBasis();
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..b7c5e8eff8a8cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex, ValueRange basis,
                                      bool hasOuterBound) {
+  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+    hasOuterBound = false;
+    basis = basis.drop_front();
+  }
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      Value linearIndex,
                                      ArrayRef<OpFoldResult> basis,
                                      bool hasOuterBound) {
+  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+    hasOuterBound = false;
+    basis = basis.drop_front();
+  }
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4614,6 +4622,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
   return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
 }
 
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+  SmallVector<OpFoldResult> ret = getMixedBasis();
+  if (!hasOuterBound())
+    ret.insert(ret.begin(), OpFoldResult());
+  return ret;
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4632,25 +4647,27 @@ struct DropUnitExtentBasis
       return zero.value();
     };
 
-    bool hasOuterBound = delinearizeOp.hasOuterBound();
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
     SmallVector<OpFoldResult> newBasis;
-    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
-      std::optional<int64_t> basisVal = getConstantIntValue(basis);
+    for (auto [index, basis] :
+         llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+      std::optional<int64_t> basisVal =
+          basis ? getConstantIntValue(basis) : std::nullopt;
       if (basisVal && *basisVal == 1)
-        replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+        replacements[index] = getZero();
       else
         newBasis.push_back(basis);
     }
 
-    if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+    if (newBasis.size() == delinearizeOp.getNumResults())
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "no unit basis elements");
 
-    if (!newBasis.empty() || !hasOuterBound) {
+    if (!newBasis.empty()) {
+      // Will drop the leading nullptr from `basis` if there was no outer bound.
       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-          loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+          loc, delinearizeOp.getLinearIndex(), newBasis);
       int newIndex = 0;
       // Map back the new delinearized indices to the values they replace.
       for (auto &replacement : replacements) {
@@ -4831,6 +4848,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    OperationState &odsState,
                                    ValueRange multiIndex, ValueRange basis,
                                    bool disjoint) {
+  if (!basis.empty() && basis.front() == Value())
+    basis = basis.drop_front();
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4843,6 +4862,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    ValueRange multiIndex,
                                    ArrayRef<OpFoldResult> basis,
                                    bool disjoint) {
+  if (!basis.empty() && basis.front() == OpFoldResult())
+    basis = basis.drop_front();
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4918,7 +4939,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
                           builder);
   }
 
-  return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+  SmallVector<OpFoldResult> ret = getMixedBasis();
+  if (!hasOuterBound())
+    ret.insert(ret.begin(), OpFoldResult());
+  return ret;
 }
 
 namespace {
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                                   ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(term.get<Value>());
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+///   by (...e, B1, B2, ..., BK, ...f)
+/// ```
 ///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
 /// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
 /// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
     : OpRewritePattern<affine::AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  struct Match {
+    AffineDelinearizeIndexOp delinearize;
+    unsigned linStart = 0;
+    unsigned delinStart = 0;
+    unsigned length = 0;
+  };
+
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto delinearizeOp = linearizeOp.getMultiIndex()
-                             .front()
-                             .getDefiningOp<affine::AffineDelinearizeIndexOp>();
-    if (!delinearizeOp)
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "last entry doesn't come from a delinearize");
+    SmallVector<Match> matches;
+
+    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+    ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+    ValueRange multiIndex = linearizeOp.getMultiIndex();
+    unsigned numLinArgs = multiIndex.size();
+    unsigned linArgIdx = 0;
+    // We only want to replace one run from the same delinearize op per
+    // pattern invocation lest we run into invalidation issues.
+    llvm::SmallPtrSet<Operation *, 2> seen;
+    while (linArgIdx < numLinArgs) {
+      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+      if (!asResult) {
+        linArgIdx++;
+        continue;
+      }
 
-    if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "basis of linearize and delinearize don't match exactly "
-                       "(excluding outer bounds)");
+      auto delinearizeOp =
+          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+      if (!delinearizeOp) {
+        linArgIdx++;
+        continue;
+      }
+
+      /// Result 0 of the delinearize and argument 0 of the linearize can
+      /// leave their maximum value unspecified. However, even if this happens
+      /// we can still sometimes start the match process. Specifically, if
+      /// - The argument we're matching is result 0 and argument 0 (so the
+      /// bounds don't matter). For example,
+      ///
+      ///     %0:2 = affine.delinearize_index %x into (8) : index, index
+      ///     %1 = affine.linearize_index [%0#0, %0#1, ...] by (8, ...)
+      /// allows cancellation
+      /// - The delinearization doesn't specify a bound, but the linearization
+      ///  is `disjoint`, which asserts that the bound on the linearization is
+      ///  correct.
+      unsigned firstDelinArg = asResult.getResultNumber();
+      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+      OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+      OpFoldResult firstLinBound = linBasis[linArgIdx];
+      bool boundsMatch = firstDelinBound == firstLinBound;
+      bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+      bool knownByDisjoint =
+          linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+        linArgIdx++;
+        continue;
+      }
 
-    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      unsigned j = 1;
+      unsigned numDelinOuts = delinearizeOp.getNumResults();
+      for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+           ++j) {
+        if (multiIndex[linArgIdx + j] !=
+            delinearizeOp.getResult(firstDelinArg + j))
+          break;
+        if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+          break;
+      }
+      // If there're multiple matches against the same delinearize_index,
+      // only rewrite the first one we find to prevent invalidations. The next
+      // ones will be taken care of by subsequent pattern invocations.
+      if (j <= 1 || !seen.insert(delinearizeOp).second) {
+        linArgIdx++;
+        continue;
+      }
+      matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+      linArgIdx += j;
+    }
+
+    if (matches.empty())
       return rewriter.notifyMatchFailure(
-          linearizeOp, "not all indices come from delinearize");
+          linearizeOp, "no run of delinearize outputs to deal with");
+
+    SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+    SmallVector<Value> newIndex;
+    newIndex.reserve(numLinArgs);
+    SmallVector<OpFoldResult> newBasis;
+    newBasis.reserve(numLinArgs);
+    unsigned prevMatchEnd = 0;
+    for (Match m : matches) {
+      unsigned gap = m.linStart - prevMatchEnd;
+      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+      // Update here so we don't forget this during early continues
+      prevMatchEnd = m.linStart + m.length;
+
+      // We use the slice from the linearize's basis above because of the
+      // "bounds inferred from `disjoint`" case above.
+      OpFoldResult newSize =
+          computeProduct(linearizeOp.getLoc(), rewriter,
+                         linBasisRef.slice(m.linStart, m.length));
+
+      // Trivial case where we can just skip past the delinearize altogether
+      if (m.length == m.delinearize.getNumResults()) {
+        newIndex.push_back(m.delinearize.getLinearIndex());
+        newBasis.push_back(newSize);
+        continue;
+      }
+      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+                          newDelinBasis.begin() + m.delinStart + m.length);
+      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+          newDelinBasis);
+
+      // Swap all the uses of the unaffected delinearize outputs to the new
+      // delinearization so that the old code can be removed if this
+      // linearize_index is the only user of the merged results.
+      llvm::append_range(
+          delinearizeReplacements,
+          llvm::zip_equal(
+              m.delinearize.getResults().take_front(m.delinStart),
+              newDelinearize.getResults().take_front(m.delinStart)));
+      llvm::append_range(
+          delinearizeReplacements,
+          llvm::zip_equal(
+              m.delinearize.getResults().drop_front(m.delinStart + m.length),
+              newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+      Value newLinArg = newDelinearize.getResult(m.delinStart);
+      newIndex.push_back(newLinArg);
+      newBasis.push_back(newSize);
+    }
+    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
 
-    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    for (auto [from, to] : delinearizeReplacements)
+      rewriter.replaceAllUsesWith(from, to);
     return success();
   }
 };
@@ -5049,7 +5241,7 @@ struct DropLinearizeLeadingZero final
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..c153d32670d574 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
 
 // -----
 
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
 //       CHECK:     return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
   %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
   %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
   return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
 
 // -----
 
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]...
[truncated]

@llvmbot
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

-          loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+          loc, delinearizeOp.getLinearIndex(), newBasis);
       int newIndex = 0;
       // Map back the new delinearized indices to the values they replace.
       for (auto &replacement : replacements) {
@@ -4831,6 +4848,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    OperationState &odsState,
                                    ValueRange multiIndex, ValueRange basis,
                                    bool disjoint) {
+  if (!basis.empty() && basis.front() == Value())
+    basis = basis.drop_front();
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4843,6 +4862,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    ValueRange multiIndex,
                                    ArrayRef<OpFoldResult> basis,
                                    bool disjoint) {
+  if (!basis.empty() && basis.front() == OpFoldResult())
+    basis = basis.drop_front();
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4918,7 +4939,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
                           builder);
   }
 
-  return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+  SmallVector<OpFoldResult> ret = getMixedBasis();
+  if (!hasOuterBound())
+    ret.insert(ret.begin(), OpFoldResult());
+  return ret;
 }
 
 namespace {
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                                   ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(term.get<Value>());
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+///   by (...e, B1, B2, ..., BK, ...f)
+/// ```
 ///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
 /// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
 /// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
     : OpRewritePattern<affine::AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  struct Match {
+    AffineDelinearizeIndexOp delinearize;
+    unsigned linStart = 0;
+    unsigned delinStart = 0;
+    unsigned length = 0;
+  };
+
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto delinearizeOp = linearizeOp.getMultiIndex()
-                             .front()
-                             .getDefiningOp<affine::AffineDelinearizeIndexOp>();
-    if (!delinearizeOp)
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "last entry doesn't come from a delinearize");
+    SmallVector<Match> matches;
+
+    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+    ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+    ValueRange multiIndex = linearizeOp.getMultiIndex();
+    unsigned numLinArgs = multiIndex.size();
+    unsigned linArgIdx = 0;
+    // We only want to replace one run from the same delinearize op per
+    // pattern invocation lest we run into invalidation issues.
+    llvm::SmallPtrSet<Operation *, 2> seen;
+    while (linArgIdx < numLinArgs) {
+      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+      if (!asResult) {
+        linArgIdx++;
+        continue;
+      }
 
-    if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "basis of linearize and delinearize don't match exactly "
-                       "(excluding outer bounds)");
+      auto delinearizeOp =
+          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+      if (!delinearizeOp) {
+        linArgIdx++;
+        continue;
+      }
+
+      /// Result 0 of the delinearize and argument 0 of the linearize can
+      /// leave their maximum value unspecified. However, even if this happens
+      /// we can still sometimes start the match process. Specifically, if
+      /// - The argument we're matching is result 0 and argument 0 (so the
+      /// bounds don't matter). For example,
+      ///
+      ///     %0:2 = affine.delinearize_index %x into (8) : index, index
+      ///     %1 = affine.linearize_index [%0#0, %0#1, ...] by (8, ...)
+      /// allows cancellation
+      /// - The delinearization doesn't specify a bound, but the linearization
+      ///  is `disjoint`, which asserts that the bound on the linearization is
+      ///  correct.
+      unsigned firstDelinArg = asResult.getResultNumber();
+      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+      OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+      OpFoldResult firstLinBound = linBasis[linArgIdx];
+      bool boundsMatch = firstDelinBound == firstLinBound;
+      bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+      bool knownByDisjoint =
+          linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+        linArgIdx++;
+        continue;
+      }
 
-    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      unsigned j = 1;
+      unsigned numDelinOuts = delinearizeOp.getNumResults();
+      for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+           ++j) {
+        if (multiIndex[linArgIdx + j] !=
+            delinearizeOp.getResult(firstDelinArg + j))
+          break;
+        if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+          break;
+      }
+      // If there're multiple matches against the same delinearize_index,
+      // only rewrite the first one we find to prevent invalidations. The next
+      // ones will be taken care of by subsequent pattern invocations.
+      if (j <= 1 || !seen.insert(delinearizeOp).second) {
+        linArgIdx++;
+        continue;
+      }
+      matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+      linArgIdx += j;
+    }
+
+    if (matches.empty())
       return rewriter.notifyMatchFailure(
-          linearizeOp, "not all indices come from delinearize");
+          linearizeOp, "no run of delinearize outputs to deal with");
+
+    SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+    SmallVector<Value> newIndex;
+    newIndex.reserve(numLinArgs);
+    SmallVector<OpFoldResult> newBasis;
+    newBasis.reserve(numLinArgs);
+    unsigned prevMatchEnd = 0;
+    for (Match m : matches) {
+      unsigned gap = m.linStart - prevMatchEnd;
+      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+      // Update here so we don't forget this during early continues
+      prevMatchEnd = m.linStart + m.length;
+
+      // We use the slice from the linearize's basis above because of the
+      // "bounds inferred from `disjoint`" case above.
+      OpFoldResult newSize =
+          computeProduct(linearizeOp.getLoc(), rewriter,
+                         linBasisRef.slice(m.linStart, m.length));
+
+      // Trivial case where we can just skip past the delinearize altogether
+      if (m.length == m.delinearize.getNumResults()) {
+        newIndex.push_back(m.delinearize.getLinearIndex());
+        newBasis.push_back(newSize);
+        continue;
+      }
+      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+                          newDelinBasis.begin() + m.delinStart + m.length);
+      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+          newDelinBasis);
+
+      // Swap all the uses of the unaffected delinearize outputs to the new
+      // delinearization so that the old code can be removed if this
+      // linearize_index is the only user of the merged results.
+      llvm::append_range(
+          delinearizeReplacements,
+          llvm::zip_equal(
+              m.delinearize.getResults().take_front(m.delinStart),
+              newDelinearize.getResults().take_front(m.delinStart)));
+      llvm::append_range(
+          delinearizeReplacements,
+          llvm::zip_equal(
+              m.delinearize.getResults().drop_front(m.delinStart + m.length),
+              newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+      Value newLinArg = newDelinearize.getResult(m.delinStart);
+      newIndex.push_back(newLinArg);
+      newBasis.push_back(newSize);
+    }
+    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
 
-    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    for (auto [from, to] : delinearizeReplacements)
+      rewriter.replaceAllUsesWith(from, to);
     return success();
   }
 };
@@ -5049,7 +5241,7 @@ struct DropLinearizeLeadingZero final
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..c153d32670d574 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
 
 // -----
 
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
 //       CHECK:     return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
   %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
   %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
   return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
 
 // -----
 
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]...
[truncated]

Contributor

@Abhishek-Varma Abhishek-Varma left a comment

Thanks for adding this @krzysz00 !

/// Return the vector with one basis element per index operand of the operation.
/// If there is no outer bound specified, the leading entry of this basis will be
/// nullptr.
SmallVector<OpFoldResult> getPaddedBasis();
Contributor

Would getPaddedMixedBasis() make more sense for the name here?
And accordingly we can make the doc comment simpler: "Same as getMixedBasis, but the leading entry will be nullptr if no outer bound is specified"?

Contributor Author

It might, and ... yeah, I'm not strongly attached to the name

Contributor Author

But also it's a bit wordier - let me know if you have strong thoughts here

Contributor

Sure, no worries - but I'd still prefer making the doc comment simpler: "Same as getMixedBasis, but the leading entry will be nullptr if no outer bound is specified."
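
To make the naming discussion concrete, here is a minimal sketch of the padding behaviour in plain C++, with no MLIR dependency. `BasisElem` and `paddedBasis` are hypothetical stand-ins that are not part of the patch: `std::nullopt` plays the role of the null `OpFoldResult`, and `numIndices` models the number of results (delinearize) or operands (linearize) of the op.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for OpFoldResult: nullopt plays the role of the nullptr entry.
using BasisElem = std::optional<int64_t>;

// Hypothetical model of getPaddedBasis(): `basis` is what the op stores and
// `numIndices` is how many index values the op produces or consumes.
std::vector<BasisElem> paddedBasis(const std::vector<int64_t> &basis,
                                   size_t numIndices) {
  // The op either stores one bound per index or omits only the outer one.
  assert(basis.size() == numIndices || basis.size() + 1 == numIndices);
  std::vector<BasisElem> padded;
  if (basis.size() + 1 == numIndices)
    padded.push_back(std::nullopt); // no outer bound: pad the front
  padded.insert(padded.end(), basis.begin(), basis.end());
  return padded;
}
```

With this shape a caller can index the basis by result number directly, which appears to be what lets `DropUnitExtentBasis` drop its `hasOuterBound ? 0 : 1` offset in the diff above.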

@krzysz00 krzysz00 force-pushed the collapse-contiguous-merges-linearize-delinearize branch from 7f014fb to 6a8d42d Compare December 2, 2024 17:59
Contributor

@Abhishek-Varma Abhishek-Varma left a comment

It'd be nice to make the current pattern work for the following case:

func.func private @myfunc(%arg0: index) -> (index, index, index) {
  %c0 = arith.constant 0 : index
  %0:3 = affine.delinearize_index %arg0 into (8, 8, 4) : index, index, index
  %1 = affine.linearize_index disjoint [%0#2, %0#0, %c0, %c0] by (4, 8, 4, 8) : index
  %2 = affine.linearize_index disjoint [%0#1, %0#2, %c0, %c0] by (8, 4, 8, 4) : index
  %3 = affine.linearize_index disjoint [%0#1, %0#0, %c0, %c0] by (8, 8, 4, 4) : index
  return %1, %2, %3 : index, index, index
}

This is the case where multiple affine.linearize_index ops consume results of the same affine.delinearize_index. Currently the pattern fails here with "operand <NUM> does not dominate this use", that is, it produces invalid IR.

Perhaps a setInsertionPoint* call is missing in the pattern.
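
For reference, the run detection can be modelled outside MLIR to see which of the three `affine.linearize_index` ops above would trigger the rewrite. This is a simplified sketch, not the pattern itself: `Run` and `findRuns` are hypothetical names, both bases are assumed to be static and to carry outer bounds, only a single delinearize is considered, and every result number is assumed to be in range.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// One run of consecutive delinearize results found in a linearize's operands.
struct Run {
  unsigned linStart;   // position of the first matched linearize operand
  unsigned delinStart; // result number of the first matched delinearize result
  unsigned length;     // number of results in the run (always >= 2)
};

// `args[i]` holds the delinearize result number feeding linearize operand i,
// or nullopt for any other value (for example a constant).
std::vector<Run> findRuns(const std::vector<std::optional<unsigned>> &args,
                          const std::vector<int64_t> &linBasis,
                          const std::vector<int64_t> &delinBasis) {
  assert(args.size() == linBasis.size());
  std::vector<Run> runs;
  unsigned i = 0;
  while (i < args.size()) {
    // A run can only start at a delinearize result whose bound matches.
    if (!args[i] || linBasis[i] != delinBasis[*args[i]]) {
      ++i;
      continue;
    }
    unsigned first = *args[i];
    unsigned j = 1;
    // Extend while operands are the next results and the bounds agree.
    while (i + j < args.size() && first + j < delinBasis.size() &&
           args[i + j] == first + j &&
           linBasis[i + j] == delinBasis[first + j])
      ++j;
    // Keep at most one run, mirroring the one-run-per-delinearize rule.
    if (j < 2 || !runs.empty()) {
      ++i;
      continue;
    }
    runs.push_back({i, first, j});
    i += j;
  }
  return runs;
}
```

Under this model only `%2` contains a run (`%0#1, %0#2` against `(8, 4)`). If the replacement `affine.delinearize_index` is created at `%2` while `%0#0` is redirected to it, the earlier user `%1` would precede its new operand, which would be consistent with the dominance error reported here.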

The existing canonicalization patterns would only cancel out cases
where the entire result list of an affine.delinearize_index was passed
to an affine.linearize_index and the basis elements matched
exactly (except possibly for the outer bounds).

This was correct, but limited, and left open many cases where a
delinearize_index would take a series of divisions and modulos only
for a subsequent linearize_index to use additions and multiplications
to undo all that work.

This sort of simplification is reasonably easy to observe at the level
of splitting and merging indexes, but difficult to perform once the
underlying arithmetic operations have been created.

Therefore, this commit generalizes the existing simplification logic.

Now, any run of two or more delinearize_index results that appears
within the argument list of a linearize_index operation with the same
basis (or where they're both at the outermost position and so can be
unbounded, or when `linearize_index disjoint` implies a bound not
present on the `delinearize_index`) will be reduced to one single
delinearize_index output, whose basis element (that is, size or
length) is equal to the product of the sizes that were simplified
away.

That is, we can now simplify

    %0:2 = affine.delinearize_index %n into (8, 8) : index, index
    %1 = affine.linearize_index [%x, %0#0, %0#1, %y] by (3, 8, 8, 5) : index

to the simpler

    %1 = affine.linearize_index [%x, %n, %y] by (3, 64, 5) : index
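
The equivalence can be checked numerically. The sketch below models both ops with plain integer arithmetic; `delinearize`, `linearize`, and `rewriteAgrees` are illustrative helpers rather than MLIR APIs, and the check assumes `0 <= %n < 64` with non-negative `%x` and `%y`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Split `linear` into one index per basis element, innermost element last.
// Assumes every basis element is positive and 0 <= linear < product(basis).
std::vector<int64_t> delinearize(int64_t linear,
                                 const std::vector<int64_t> &basis) {
  std::vector<int64_t> indices(basis.size());
  for (size_t i = basis.size(); i-- > 0;) {
    indices[i] = linear % basis[i];
    linear /= basis[i];
  }
  return indices;
}

// Merge `indices` back into one value using the same row-major layout.
// The outermost basis element never affects the result.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &basis) {
  assert(indices.size() == basis.size());
  int64_t linear = 0;
  for (size_t i = 0; i < indices.size(); ++i)
    linear = linear * basis[i] + indices[i];
  return linear;
}

// Compare the original and simplified forms over every in-range input.
bool rewriteAgrees() {
  for (int64_t x = 0; x < 3; ++x)
    for (int64_t n = 0; n < 64; ++n)
      for (int64_t y = 0; y < 5; ++y) {
        std::vector<int64_t> d = delinearize(n, {8, 8});
        int64_t before = linearize({x, d[0], d[1], y}, {3, 8, 8, 5});
        int64_t after = linearize({x, n, y}, {3, 64, 5});
        if (before != after)
          return false;
      }
  return true;
}
```

Both forms reduce to `(x * 64 + n) * 5 + y`, since `d[0] * 8 + d[1]` reconstructs `n`; the merged basis element 64 is the product of the two sizes that were simplified away.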

This new pattern also works with dynamically-sized basis values.

While I'm here, I fixed a bunch of typos in existing tests, and added
a new getPaddedBasis() method to make processing
potentially-underspecified basis elements simpler in some cases.
@krzysz00 krzysz00 force-pushed the collapse-contiguous-merges-linearize-delinearize branch from 6a8d42d to 8c53d39 Compare December 4, 2024 18:44
Contributor

@Abhishek-Varma Abhishek-Varma left a comment

LGTM! Thanks @krzysz00 for the changes!

@krzysz00 krzysz00 changed the title [mlir][Affine] Genarilze the linearize(delinearize()) simplifications [mlir][Affine] Generalize the linearize(delinearize()) simplifications Jan 3, 2025
@krzysz00 krzysz00 merged commit 9f5cefe into llvm:main Jan 3, 2025
8 checks passed
@llvm-ci
Collaborator

llvm-ci commented Jan 3, 2025

LLVM Buildbot has detected a new failure on builder sanitizer-aarch64-linux-bootstrap-hwasan running on sanitizer-buildbot12 while building mlir at step 2 "annotate".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/55/builds/5033

Here is the relevant piece of the build log for reference
Step 2 (annotate) failure: 'python ../sanitizer_buildbot/sanitizers/zorg/buildbot/builders/sanitizers/buildbot_selector.py' (failure)
...
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using lld-link: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/lld-link
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using ld64.lld: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/ld64.lld
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using wasm-ld: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/wasm-ld
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using ld.lld: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/ld.lld
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using lld-link: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/lld-link
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using ld64.lld: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/ld64.lld
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/llvm/config.py:506: note: using wasm-ld: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/wasm-ld
llvm-lit: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/utils/lit/lit/main.py:72: note: The test suite configuration requested an individual test timeout of 0 seconds but a timeout of 900 seconds was requested on the command line. Forcing timeout to be 900 seconds.
-- Testing: 85697 tests, 72 workers --
Testing:  0.. 10.. 20.. 30.. 40.. 50.. 
FAIL: LLVM :: ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s (52291 of 85697)
******************** TEST 'LLVM :: ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
RUN: at line 1: /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/llvm-mc -filetype=obj -triple=x86_64-windows-msvc /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/test/ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s -o /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/test/ExecutionEngine/JITLink/x86-64/Output/COFF_directive_alternatename_fail.s.tmp
+ /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/llvm-mc -filetype=obj -triple=x86_64-windows-msvc /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/test/ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s -o /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/test/ExecutionEngine/JITLink/x86-64/Output/COFF_directive_alternatename_fail.s.tmp
RUN: at line 2: not /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/llvm-jitlink -noexec /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/test/ExecutionEngine/JITLink/x86-64/Output/COFF_directive_alternatename_fail.s.tmp 2>&1 | /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/FileCheck /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/test/ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s
+ /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/FileCheck /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm-project/llvm/test/ExecutionEngine/JITLink/x86-64/COFF_directive_alternatename_fail.s
+ not /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/bin/llvm-jitlink -noexec /home/b/sanitizer-aarch64-linux-bootstrap-hwasan/build/llvm_build_hwasan/test/ExecutionEngine/JITLink/x86-64/Output/COFF_directive_alternatename_fail.s.tmp

--

********************
Testing:  0.. 10.. 20.. 30.. 40.. 50.. 60.. 70.. 80.. 90.. 
Slowest Tests:
--------------------------------------------------------------------------
56.42s: Clang :: Driver/fsanitize.c
43.37s: Clang :: Preprocessor/riscv-target-features.c
39.81s: Clang :: Driver/arm-cortex-cpus-2.c
38.72s: Clang :: Driver/arm-cortex-cpus-1.c
35.77s: LLVM :: CodeGen/AMDGPU/sched-group-barrier-pipeline-solver.mir
35.22s: Clang :: OpenMP/target_update_codegen.cpp
34.96s: Clang :: OpenMP/target_defaultmap_codegen_01.cpp
31.04s: Clang :: Preprocessor/aarch64-target-features.c
30.65s: Clang :: Preprocessor/arm-target-features.c
27.94s: Clang :: Driver/clang_f_opts.c
26.98s: Clang :: Preprocessor/predefined-arch-macros.c
26.30s: Clang :: Driver/linux-ld.c
25.74s: LLVM :: CodeGen/RISCV/attributes.ll
23.94s: LLVM :: CodeGen/ARM/build-attributes.ll
23.82s: LLVM :: tools/llvm-reduce/parallel-workitem-kill.ll
23.62s: Clang :: Driver/cl-options.c
22.57s: Clang :: Driver/x86-target-features.c
22.25s: Clang :: Analysis/a_flaky_crash.cpp
20.62s: LLVM :: CodeGen/AMDGPU/memintrinsic-unroll.ll
19.30s: Clang :: Driver/debug-options.c

Step 11 (stage2/hwasan check) failure: stage2/hwasan check (failure)
