Skip to content

[MLIR][Affine] Fix affine loop fusion with vector ops #115849, #120227 #122799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 146 additions & 17 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
Expand All @@ -23,14 +24,20 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <optional>
#include <sstream>

Expand Down Expand Up @@ -177,13 +184,68 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}

/// Performs two checks:
/// Firstly, checks if both src/dst ops are vector operations.
/// Secondly, the shapes of the loads and stores of each memref in
/// the producer/consumer chains. If the load shapes are larger
/// than the stores then we cannot fuse the loops. The loads
/// would have a dependency on the values stored.
static bool checkVectorLoadStoreOps(unsigned srcId, unsigned dstId,
DenseSet<Value> &producerConsumerMemrefs,
MemRefDependenceGraph *mdg) {
SmallVector<Operation *> storeOps;
SmallVector<Operation *> loadOps;

auto *srcNode = mdg->getNode(srcId);
auto *dstNode = mdg->getNode(dstId);

for (Value memref : producerConsumerMemrefs) {
srcNode->getStoreOpsForMemref(memref, &storeOps);
dstNode->getLoadOpsForMemref(memref, &loadOps);

for (Operation *storeOp : storeOps) {
auto vectorStoreOp = dyn_cast<AffineVectorStoreOp>(storeOp);

if (!vectorStoreOp)
continue;

auto storeVecType = vectorStoreOp.getVectorType();

for (Operation *loadOp : loadOps) {
auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(loadOp);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop blank lines like these - here, above and below.

if (!vectorLoadOp)
return false;

auto loadVecType = vectorLoadOp.getVectorType();

if (loadVecType.getRank() != storeVecType.getRank())
return false;

for (int i = 0; i < loadVecType.getRank(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use unsigned i = 0, e = ... form to prevent repeated evaluation.

auto loadDim = loadVecType.getDimSize(i);
auto storeDim = storeVecType.getDimSize(i);

if (loadDim > storeDim)
return false;
}
}
}

storeOps.clear();
loadOps.clear();
}

return true;
}

/// A memref escapes in the context of the fusion pass if either:
/// 1. it (or its alias) is a block argument, or
/// 2. created by an op not known to guarantee alias freedom,
/// 3. it (or its alias) are used by ops other than affine dereferencing ops
/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
/// terminator ops, etc.); such ops do not deference the memref in an affine
/// way.
/// 3. it (or its alias) are used by ops other than affine dereferencing
/// ops (e.g., by call op, memref load/store ops, alias creating ops,
/// unknown ops, terminator ops, etc.); such ops do not deference the
/// memref in an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
Operation *defOp = memref.getDefiningOp();
// Check if 'memref' is a block argument.
Expand Down Expand Up @@ -237,6 +299,58 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
node->op = newRootForOp;
}

/// Creates a private memref to be used by vector operations.
/// TODO: The difference between this and 'createPrivateMemRef' is that
/// the system for calculating the bounds and constraints doesn't
/// support vector operations. Thus, we use the shape of the vector
Comment on lines +302 to +305
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this looks wrong. The vector shape can't be the shape of the private memref. We'll have to use the constraint system. You aren't assuming single element private memref, right?

/// as our newly created private memref instead of using a constraint
/// system.
static Value createPrivateVectorOpMemRef(
AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth,
std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) {
Comment on lines +302 to +310
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't create a new method duplicating some of the logic. Instead, unify the two, i.e., extend createPrivateMemRef for this.

Operation *forInst = forOp.getOperation();

// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentRegion());

auto srcAffineOp = cast<AffineVectorStoreOp>(srcStoreOpInst);

auto oldMemRef = srcAffineOp.getMemRef();
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();

auto vecType = srcAffineOp.getVectorType();

auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
assert(eltSize && "memrefs with size elt types expected");
uint64_t bufSize = *eltSize * vecType.getNumElements();
unsigned newMemSpace;
if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
newMemSpace = *fastMemorySpace;
} else {
newMemSpace = oldMemRefType.getMemorySpaceAsInt();
}

auto newMemRefType = MemRefType::get(
vecType.getShape(), oldMemRefType.getElementType(), {}, newMemSpace);

Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

auto indexRemap = AffineMap::getMultiDimIdentityMap(rank, forOp.getContext());

// Replace all users of 'oldMemRef' with 'newMemRef'.
LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/&*forOp.getBody()->begin());
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
return newMemRef;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
Expand Down Expand Up @@ -306,9 +420,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
} else {
newMemSpace = oldMemRefType.getMemorySpaceAsInt();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop white space.

auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
{}, newMemSpace);

// Create new private memref for fused loop 'forOp'. 'newShape' is always
// a constant shape.
// TODO: Create/move alloc ops for private memrefs closer to their
Expand All @@ -322,7 +436,6 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
remapExprs.reserve(rank);
for (unsigned i = 0; i < rank; i++) {
auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);

auto remapExpr =
simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
remapExprs.push_back(remapExpr);
Expand All @@ -340,6 +453,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;

return newMemRef;
}

Expand Down Expand Up @@ -516,6 +630,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
// nest slice 'slice' were to be inserted into the dst loop nest at loop
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop white space change.

if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
&slice))) {
LLVM_DEBUG(llvm::dbgs()
Expand Down Expand Up @@ -798,7 +913,6 @@ struct GreedyFusion {
/// No fusion is performed when producers with a user count greater than
/// `maxSrcUserCount` for any of the memrefs involved.
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
return;
Expand Down Expand Up @@ -858,8 +972,17 @@ struct GreedyFusion {
}))
continue;

// Gather memrefs in 'srcNode' that are written and escape out of the
// block (e.g., memref block arguments, returned memrefs,
if (!checkVectorLoadStoreOps(srcId, dstId, producerConsumerMemrefs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

areIncomptibleVectorLoadStore?

mdg)) {
LLVM_DEBUG(llvm::dbgs()
<< "Can't fuse: vector loop fusion invalid due to either "
"src or dst ops are not affine vector ops or load "
"dependent on a larger store region\n");
continue;
}

// Gather memrefs in 'srcNode' that are written and escape out of
// the block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
Expand Down Expand Up @@ -907,7 +1030,6 @@ struct GreedyFusion {
dstMemrefOps.push_back(op);
unsigned dstLoopDepthTest =
getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;

// Check the feasibility of fusing src loop nest into dst loop nest
// at loop depths in range [1, dstLoopDepthTest].
unsigned maxLegalFusionDepth = 0;
Expand Down Expand Up @@ -976,9 +1098,6 @@ struct GreedyFusion {
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
removeSrcNode)) {
// Create a private version of this memref.
LLVM_DEBUG(llvm::dbgs()
<< "Creating private memref for " << memref << '\n');
Comment on lines -979 to -980
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these debug lines deleted?

// Create a private version of this memref.
privateMemrefs.insert(memref);
}
}
Expand Down Expand Up @@ -1019,9 +1138,19 @@ struct GreedyFusion {
// private memref footprint.
SmallVector<Operation *, 4> &storesForMemref =
memrefToStoresPair.second;
Value newMemRef = createPrivateMemRef(
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
Operation *srcStoreOpInst = storesForMemref[0];
Value newMemRef;

if (isa<AffineVectorLoadOp, AffineVectorStoreOp>(srcStoreOpInst)) {
newMemRef = createPrivateVectorOpMemRef(
dstAffineForOp, srcStoreOpInst, bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
} else {
newMemRef = createPrivateMemRef(dstAffineForOp, srcStoreOpInst,
bestDstLoopDepth, fastMemorySpace,
localBufSizeThreshold);
}

// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
Expand Down
Loading
Loading