Skip to content

Commit

Permalink
[mlir][Affine] Extend linearize/delinearize cancelation to partial ta…
Browse files Browse the repository at this point in the history
…ils (llvm#116872)

xisting patterns would cancel out the linearize_index /
delinearize_index pairs that had the exact same basis, like

    %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...

This commit extends the canonicalization to handle instances where the
entire basis doesn't match, as in

    %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:3 = affine.delinearize_index %0 into (XY, Y, Z) : index, ...

where we can replace the last two results of the delinearize_index
operation with the last two inputs of the linearize_index, creating a
more canonical (fewer total computations to perform) result.
  • Loading branch information
krzysz00 authored Nov 21, 2024
1 parent 6f68d03 commit 0ac889b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 11 deletions.
56 changes: 45 additions & 11 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis
};

/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
/// disjoint` and the two operations have the same basis, replace the
/// delinearizeation results with the inputs of the `affine.linearize_index`
/// since they are exact inverses of each other.
/// disjoint` and the two operations end with the same basis elements,
/// cancel those parts of the operations out because they are inverses
/// of each other.
///
/// If the operations have the same basis, cancel them entirely.
///
/// The `disjoint` flag is needed on the `affine.linearize_index` because
/// otherwise, there is no guarantee that the inputs to the linearization are
/// in-bounds the way the outputs of the delinearization would be.
struct CancelDelinearizeOfLinearizeDisjointExact
struct CancelDelinearizeOfLinearizeDisjointExactTail
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -4685,22 +4687,54 @@ struct CancelDelinearizeOfLinearizeDisjointExact
return rewriter.notifyMatchFailure(delinearizeOp,
"index doesn't come from linearize");

if (!linearizeOp.getDisjoint() ||
linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
if (!linearizeOp.getDisjoint())
return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");

ValueRange linearizeIns = linearizeOp.getMultiIndex();
// Note: we use the full basis so we don't lose outer bounds later.
SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
size_t numMatches = 0;
for (auto [linSize, delinSize] : llvm::zip(
llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
if (linSize != delinSize)
break;
++numMatches;
}

if (numMatches == 0)
return rewriter.notifyMatchFailure(
linearizeOp, "not disjoint or basis doesn't match delinearize");
delinearizeOp, "final basis element doesn't match linearize");

// The easy case: everything lines up and the basis match sup completely.
if (numMatches == linearizeBasis.size() &&
numMatches == delinearizeBasis.size() &&
linearizeIns.size() == delinearizeOp.getNumResults()) {
rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
return success();
}

rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
linearizeOp.getDisjoint());
auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
delinearizeOp.getLoc(), newLinearize,
ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
delinearizeOp.hasOuterBound());
SmallVector<Value> mergedResults(newDelinearize.getResults());
mergedResults.append(linearizeIns.take_back(numMatches).begin(),
linearizeIns.take_back(numMatches).end());
rewriter.replaceOp(delinearizeOp, mergedResults);
return success();
}
};
} // namespace

void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns
.insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
context);
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
DropUnitExtentBasis>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0:

// -----

// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index
// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]
func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
%1:3 = affine.delinearize_index %0 into (8, %arg4)
: index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}

// -----

// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
Expand Down

0 comments on commit 0ac889b

Please sign in to comment.