Skip to content

[MLIR][Affine] Fix memref replacement in affine-data-copy-generate #139016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,9 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
/// of its input list. `indexRemap`'s dimensional inputs are expected to
/// correspond to memref's indices, and its symbolic inputs if any should be
/// provided in `symbolOperands`.
///
/// `domOpFilter`, if non-null, restricts the replacement to only those
/// operations that are dominated by the former; similarly, `postDomOpFilter`
/// restricts replacement to only those operations that are postdominated by it.
//
/// If `userFilterFn` is specified, restrict replacement to only those users
/// that pass the specified filter (i.e., the filter returns true).
///
/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
/// uses of a memref without any requirement for access index rewrites as long
Expand All @@ -224,13 +223,14 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
// d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any
// extra operands, note that 'indexRemap' would just be applied to existing
// indices (%i, %j).
//
// TODO: allow extraIndices to be added at any position.
LogicalResult replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
ArrayRef<Value> symbolOperands = {}, Operation *domOpFilter = nullptr,
Operation *postDomOpFilter = nullptr, bool allowNonDereferencingOps = false,
bool replaceInDeallocOp = false);
ArrayRef<Value> symbolOperands = {},
llvm::function_ref<bool(Operation *)> userFilterFn = nullptr,
bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false);

/// Performs the same replacement as the other version above but only for the
/// dereferencing uses of `oldMemRef` in `op`, except in cases where
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,15 @@ static Value createPrivateMemRef(AffineForOp forOp,
// Replace all users of 'oldMemRef' with 'newMemRef'.
Operation *domFilter =
getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
auto userFilterFn = [&](Operation *user) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we compute dominance info inside a static utility? It is expensive, and I thought the convention is to compute it once in the entry function of a pass, and pass it around.

auto domInfo = std::make_unique<DominanceInfo>(
domFilter->getParentOfType<FunctionOpInterface>());
return domInfo->dominates(domFilter, user);
};
LogicalResult res = replaceAllMemRefUsesWith(
oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*symbolOperands=*/{}, domFilter);
/*symbolOperands=*/{}, userFilterFn);
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
Expand Down
17 changes: 10 additions & 7 deletions mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {

// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
if (failed(replaceAllMemRefUsesWith(
oldMemRef, newMemRef,
/*extraIndices=*/{ivModTwoOp},
/*indexRemap=*/AffineMap(),
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/&*forOp.getBody()->begin()))) {
auto userFilterFn = [&](Operation *user) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we compute dominance info inside a static utility? It is expensive, and I thought the convention is to compute it once in the entry function of a pass, and pass it around.

auto domInfo = std::make_unique<DominanceInfo>(
forOp->getParentOfType<FunctionOpInterface>());
return domInfo->dominates(&*forOp.getBody()->begin(), user);
};
if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
/*extraIndices=*/{ivModTwoOp},
/*indexRemap=*/AffineMap(),
/*extraOperands=*/{},
/*symbolOperands=*/{}, userFilterFn))) {
LLVM_DEBUG(
forOp.emitError("memref replacement for double buffering failed"));
ivModTwoOp.erase();
Expand Down
22 changes: 13 additions & 9 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,12 @@ static LogicalResult generateCopy(
if (begin == end)
return success();

// Record the last op in the block for which we are performing copy
// generation. We later do the memref replacement only in [begin, lastCopyOp]
// so that the original memref's used in the data movement code themselves
// don't get replaced.
Operation *lastCopyOp = end->getPrevNode();

// Is the copy out point at the end of the block where we are doing
// explicit copying.
bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
Expand Down Expand Up @@ -2145,12 +2151,6 @@ static LogicalResult generateCopy(
}
}

// Record the last operation where we want the memref replacement to end. We
// later do the memref replacement only in [begin, postDomFilter] so
// that the original memref's used in the data movement code themselves don't
// get replaced.
auto postDomFilter = std::prev(end);

// Create fully composed affine maps for each memref.
auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
Expand Down Expand Up @@ -2246,13 +2246,17 @@ static LogicalResult generateCopy(
if (!isBeginAtStartOfBlock)
prevOfBegin = std::prev(begin);

auto userFilterFn = [&](Operation *user) {
auto *ancestorUser = block->findAncestorOpInBlock(*user);
return ancestorUser && !ancestorUser->isBeforeInBlock(&*begin) &&
!lastCopyOp->isBeforeInBlock(ancestorUser);
};

// *Only* those uses within the range [begin, end) of 'block' are replaced.
(void)replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/regionSymbols,
/*symbolOperands=*/{},
/*domOpFilter=*/&*begin,
/*postDomOpFilter=*/&*postDomFilter);
/*symbolOperands=*/{}, userFilterFn);

*nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);

Expand Down
55 changes: 22 additions & 33 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
AffineMap indexRemap, ArrayRef<Value> extraOperands,
ArrayRef<Value> symbolOperands, Operation *domOpFilter,
Operation *postDomOpFilter, bool allowNonDereferencingOps,
bool replaceInDeallocOp) {
ArrayRef<Value> symbolOperands,
llvm::function_ref<bool(Operation *)> userFilterFn,
bool allowNonDereferencingOps, bool replaceInDeallocOp) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
Expand All @@ -1328,61 +1328,52 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(

std::unique_ptr<DominanceInfo> domInfo;
std::unique_ptr<PostDominanceInfo> postDomInfo;
if (domOpFilter)
domInfo = std::make_unique<DominanceInfo>(
domOpFilter->getParentOfType<FunctionOpInterface>());

if (postDomOpFilter)
postDomInfo = std::make_unique<PostDominanceInfo>(
postDomOpFilter->getParentOfType<FunctionOpInterface>());

// Walk all uses of old memref; collect ops to perform replacement. We use a
// DenseSet since an operation could potentially have multiple uses of a
// memref (although rare), and the replacement later is going to erase ops.
DenseSet<Operation *> opsToReplace;
for (auto *op : oldMemRef.getUsers()) {
// Skip this use if it's not dominated by domOpFilter.
if (domOpFilter && !domInfo->dominates(domOpFilter, op))
continue;

// Skip this use if it's not post-dominated by postDomOpFilter.
if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
for (auto *user : oldMemRef.getUsers()) {
// Check if this user doesn't pass the filter.
if (userFilterFn && !userFilterFn(user))
continue;

// Skip dealloc's - no replacement is necessary, and a memref replacement
// at other uses doesn't hurt these dealloc's.
if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
if (hasSingleEffect<MemoryEffects::Free>(user, oldMemRef) &&
!replaceInDeallocOp)
continue;

// Check if the memref was used in a non-dereferencing context. It is fine
// for the memref to be used in a non-dereferencing way outside of the
// region where this replacement is happening.
if (!isa<AffineMapAccessInterface>(*op)) {
if (!isa<AffineMapAccessInterface>(*user)) {
if (!allowNonDereferencingOps) {
LLVM_DEBUG(llvm::dbgs()
<< "Memref replacement failed: non-deferencing memref op: \n"
<< *op << '\n');
LLVM_DEBUG(
llvm::dbgs()
<< "Memref replacement failed: non-deferencing memref user: \n"
<< *user << '\n');
return failure();
}
// Non-dereferencing ops with the MemRefsNormalizable trait are
// supported for replacement.
if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
if (!user->hasTrait<OpTrait::MemRefsNormalizable>()) {
LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
"memrefs normalizable trait: \n"
<< *op << '\n');
<< *user << '\n');
return failure();
}
}

// We'll first collect and then replace --- since replacement erases the op
// that has the use, and that op could be postDomFilter or domFilter itself!
opsToReplace.insert(op);
// We'll first collect and then replace --- since replacement erases the
// user that has the use, and that user could be postDomFilter or domFilter
// itself!
opsToReplace.insert(user);
}

for (auto *op : opsToReplace) {
for (auto *user : opsToReplace) {
if (failed(replaceAllMemRefUsesWith(
oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
oldMemRef, newMemRef, user, extraIndices, indexRemap, extraOperands,
symbolOperands, allowNonDereferencingOps)))
llvm_unreachable("memref replacement guaranteed to succeed here");
}
Expand Down Expand Up @@ -1763,8 +1754,7 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/symbolOperands,
/*domOpFilter=*/nullptr,
/*postDomOpFilter=*/nullptr,
/*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true))) {
// If it failed (due to escapes for example), bail out.
newAlloc.erase();
Expand Down Expand Up @@ -1854,8 +1844,7 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
/*indexRemap=*/oldLayoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/oldStrides,
/*domOpFilter=*/nullptr,
/*postDomOpFilter=*/nullptr,
/*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true))) {
// If it failed (due to escapes for example), bail out.
newReinterpretCast.erase();
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/nullptr,
/*postDomOpFilter=*/nullptr,
/*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
// If it failed (due to escapes for example), bail out.
Expand Down Expand Up @@ -407,8 +406,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/nullptr,
/*postDomOpFilter=*/nullptr,
/*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
// If it failed (due to escapes for example), bail out. Removing the
Expand Down Expand Up @@ -457,8 +455,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
/*indexRemap=*/layoutMap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/nullptr,
/*postDomOpFilter=*/nullptr,
/*userFilterFn=*/nullptr,
/*allowNonDereferencingOps=*/true,
/*replaceInDeallocOp=*/true))) {
newOp->erase();
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/Affine/affine-data-copy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,51 @@ func.func @memref_def_inside(%arg0: index) {
// LIMITED-MEM-NEXT: memref.dealloc %{{.*}} : memref<1xf32>
return
}

// Test with uses across multiple blocks.

memref.global "private" constant @__constant_1x2x1xi32_1 : memref<1x2x1xi32> = dense<0> {alignment = 64 : i64}

// CHECK-LABEL: func @multiple_blocks
func.func @multiple_blocks(%arg0: index) -> memref<1x2x1xi32> {
%c1_i32 = arith.constant 1 : i32
%c3_i32 = arith.constant 3 : i32
%0 = memref.get_global @__constant_1x2x1xi32_1 : memref<1x2x1xi32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
memref.copy %0, %alloc : memref<1x2x1xi32> to memref<1x2x1xi32>
cf.br ^bb1(%alloc : memref<1x2x1xi32>)
^bb1(%1: memref<1x2x1xi32>): // 2 preds: ^bb0, ^bb2
// CHECK: ^bb1(%[[MEM:.*]]: memref<1x2x1xi32>):
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi1>
// CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x1xi32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A question for my understanding -- Why is the fast buffer not allocated in shared memory space?
?

affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 2 {
affine.for %arg3 = 0 to 1 {
// CHECK: affine.load %[[BUF]]
%3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
%4 = arith.cmpi slt, %3, %c3_i32 : i32
affine.store %4, %alloc_0[%arg1, %arg2, %arg3] : memref<1x2x1xi1>
}
}
}
// CHECK: memref.dealloc %[[BUF]]
%2 = memref.load %alloc_0[%arg0, %arg0, %arg0] : memref<1x2x1xi1>
cf.cond_br %2, ^bb2, ^bb3
^bb2: // pred: ^bb1
// CHECK: ^bb2
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 2 {
affine.for %arg3 = 0 to 1 {
// Ensure that this reference isn't replaced.
%3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
// CHECK: affine.load %[[MEM]]
%4 = arith.addi %3, %c1_i32 : i32
affine.store %4, %alloc_1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
}
}
}
cf.br ^bb1(%alloc_1 : memref<1x2x1xi32>)
^bb3: // pred: ^bb1
return %1 : memref<1x2x1xi32>
}