-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][Affine] Fix affine loop fusion with vector ops #115849, #120227 #122799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b1aa2c7
e2d8398
c49200c
267bd4b
50a9a20
d9ad45b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" | ||
#include "mlir/Dialect/Affine/Passes.h" | ||
|
||
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" | ||
|
@@ -23,14 +24,20 @@ | |
#include "mlir/IR/AffineExpr.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/Transforms/Passes.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/Types.h" | ||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/DenseSet.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Support/CommandLine.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
#include <cstdint> | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <optional> | ||
#include <sstream> | ||
|
||
|
@@ -177,13 +184,68 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, | |
producerConsumerMemrefs); | ||
} | ||
|
||
/// Performs two checks: | ||
/// Firstly, checks if both src/dst ops are vector operations. | ||
/// Secondly, the shapes of the loads and stores of each memref in | ||
/// the producer/consumer chains. If the load shapes are larger | ||
/// than the stores then we cannot fuse the loops. The loads | ||
/// would have a dependency on the values stored. | ||
static bool checkVectorLoadStoreOps(unsigned srcId, unsigned dstId, | ||
DenseSet<Value> &producerConsumerMemrefs, | ||
MemRefDependenceGraph *mdg) { | ||
SmallVector<Operation *> storeOps; | ||
SmallVector<Operation *> loadOps; | ||
|
||
auto *srcNode = mdg->getNode(srcId); | ||
auto *dstNode = mdg->getNode(dstId); | ||
|
||
for (Value memref : producerConsumerMemrefs) { | ||
srcNode->getStoreOpsForMemref(memref, &storeOps); | ||
dstNode->getLoadOpsForMemref(memref, &loadOps); | ||
|
||
for (Operation *storeOp : storeOps) { | ||
auto vectorStoreOp = dyn_cast<AffineVectorStoreOp>(storeOp); | ||
|
||
if (!vectorStoreOp) | ||
continue; | ||
|
||
auto storeVecType = vectorStoreOp.getVectorType(); | ||
|
||
for (Operation *loadOp : loadOps) { | ||
auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(loadOp); | ||
|
||
if (!vectorLoadOp) | ||
return false; | ||
|
||
auto loadVecType = vectorLoadOp.getVectorType(); | ||
|
||
if (loadVecType.getRank() != storeVecType.getRank()) | ||
return false; | ||
|
||
for (int i = 0; i < loadVecType.getRank(); ++i) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
auto loadDim = loadVecType.getDimSize(i); | ||
auto storeDim = storeVecType.getDimSize(i); | ||
|
||
if (loadDim > storeDim) | ||
return false; | ||
} | ||
} | ||
} | ||
|
||
storeOps.clear(); | ||
loadOps.clear(); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
/// A memref escapes in the context of the fusion pass if either: | ||
/// 1. it (or its alias) is a block argument, or | ||
/// 2. created by an op not known to guarantee alias freedom, | ||
/// 3. it (or its alias) are used by ops other than affine dereferencing ops | ||
/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops, | ||
/// terminator ops, etc.); such ops do not dereference the memref in an affine | ||
/// way. | ||
/// 3. it (or its alias) are used by ops other than affine dereferencing | ||
/// ops (e.g., by call op, memref load/store ops, alias creating ops, | ||
/// unknown ops, terminator ops, etc.); such ops do not dereference the | ||
/// memref in an affine way. | ||
static bool isEscapingMemref(Value memref, Block *block) { | ||
Operation *defOp = memref.getDefiningOp(); | ||
// Check if 'memref' is a block argument. | ||
|
@@ -237,6 +299,58 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { | |
node->op = newRootForOp; | ||
} | ||
|
||
/// Creates a private memref to be used by vector operations. | ||
/// TODO: The difference between this and 'createPrivateMemRef' is that | ||
/// the system for calculating the bounds and constraints doesn't | ||
/// support vector operations. Thus, we use the shape of the vector | ||
Comment on lines
+302
to
+305
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this looks wrong. The vector shape can't be the shape of the private memref. We'll have to use the constraint system. You aren't assuming single element private memref, right? |
||
/// as our newly created private memref instead of using a constraint | ||
/// system. | ||
static Value createPrivateVectorOpMemRef( | ||
AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth, | ||
std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) { | ||
Comment on lines
+302
to
+310
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We shouldn't create a new method duplicating some of the logic. Instead, unify the two, i.e., extend |
||
Operation *forInst = forOp.getOperation(); | ||
|
||
// Create builder to insert alloc op just before 'forOp'. | ||
OpBuilder b(forInst); | ||
// Builder to create constants at the top level. | ||
OpBuilder top(forInst->getParentRegion()); | ||
|
||
auto srcAffineOp = cast<AffineVectorStoreOp>(srcStoreOpInst); | ||
|
||
auto oldMemRef = srcAffineOp.getMemRef(); | ||
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType()); | ||
unsigned rank = oldMemRefType.getRank(); | ||
|
||
auto vecType = srcAffineOp.getVectorType(); | ||
|
||
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType); | ||
assert(eltSize && "memrefs with size elt types expected"); | ||
uint64_t bufSize = *eltSize * vecType.getNumElements(); | ||
unsigned newMemSpace; | ||
if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) { | ||
newMemSpace = *fastMemorySpace; | ||
} else { | ||
newMemSpace = oldMemRefType.getMemorySpaceAsInt(); | ||
} | ||
|
||
auto newMemRefType = MemRefType::get( | ||
vecType.getShape(), oldMemRefType.getElementType(), {}, newMemSpace); | ||
|
||
Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType); | ||
|
||
auto indexRemap = AffineMap::getMultiDimIdentityMap(rank, forOp.getContext()); | ||
|
||
// Replace all users of 'oldMemRef' with 'newMemRef'. | ||
LogicalResult res = | ||
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, | ||
/*extraOperands=*/{}, | ||
/*symbolOperands=*/{}, | ||
/*domOpFilter=*/&*forOp.getBody()->begin()); | ||
assert(succeeded(res) && | ||
"replaceAllMemrefUsesWith should always succeed here"); | ||
return newMemRef; | ||
} | ||
|
||
// Creates and returns a private (single-user) memref for fused loop rooted | ||
// at 'forOp', with (potentially reduced) memref size based on the | ||
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. | ||
|
@@ -306,9 +420,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, | |
} else { | ||
newMemSpace = oldMemRefType.getMemorySpaceAsInt(); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop white space. |
||
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), | ||
{}, newMemSpace); | ||
|
||
// Create new private memref for fused loop 'forOp'. 'newShape' is always | ||
// a constant shape. | ||
// TODO: Create/move alloc ops for private memrefs closer to their | ||
|
@@ -322,7 +436,6 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, | |
remapExprs.reserve(rank); | ||
for (unsigned i = 0; i < rank; i++) { | ||
auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); | ||
|
||
auto remapExpr = | ||
simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); | ||
remapExprs.push_back(remapExpr); | ||
|
@@ -340,6 +453,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, | |
assert(succeeded(res) && | ||
"replaceAllMemrefUsesWith should always succeed here"); | ||
(void)res; | ||
|
||
return newMemRef; | ||
} | ||
|
||
|
@@ -516,6 +630,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, | |
// nest slice 'slice' were to be inserted into the dst loop nest at loop | ||
// depth 'i'. | ||
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop white space change. |
||
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, | ||
&slice))) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
|
@@ -798,7 +913,6 @@ struct GreedyFusion { | |
/// No fusion is performed when producers with a user count greater than | ||
/// `maxSrcUserCount` for any of the memrefs involved. | ||
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { | ||
LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); | ||
// Skip if this node was removed (fused into another node). | ||
if (mdg->nodes.count(dstId) == 0) | ||
return; | ||
|
@@ -858,8 +972,17 @@ struct GreedyFusion { | |
})) | ||
continue; | ||
|
||
// Gather memrefs in 'srcNode' that are written and escape out of the | ||
// block (e.g., memref block arguments, returned memrefs, | ||
if (!checkVectorLoadStoreOps(srcId, dstId, producerConsumerMemrefs, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. areIncompatibleVectorLoadStore? |
||
mdg)) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Can't fuse: vector loop fusion invalid due to either " | ||
"src or dst ops are not affine vector ops or load " | ||
"dependent on a larger store region\n"); | ||
continue; | ||
} | ||
|
||
// Gather memrefs in 'srcNode' that are written and escape out of | ||
// the block (e.g., memref block arguments, returned memrefs, | ||
// memrefs passed to function calls, etc.). | ||
DenseSet<Value> srcEscapingMemRefs; | ||
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); | ||
|
@@ -907,7 +1030,6 @@ struct GreedyFusion { | |
dstMemrefOps.push_back(op); | ||
unsigned dstLoopDepthTest = | ||
getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops; | ||
|
||
// Check the feasibility of fusing src loop nest into dst loop nest | ||
// at loop depths in range [1, dstLoopDepthTest]. | ||
unsigned maxLegalFusionDepth = 0; | ||
|
@@ -976,9 +1098,6 @@ struct GreedyFusion { | |
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, | ||
removeSrcNode)) { | ||
// Create a private version of this memref. | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Creating private memref for " << memref << '\n'); | ||
Comment on lines
-979
to
-980
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are these debug lines deleted? |
||
// Create a private version of this memref. | ||
privateMemrefs.insert(memref); | ||
} | ||
} | ||
|
@@ -1019,9 +1138,19 @@ struct GreedyFusion { | |
// private memref footprint. | ||
SmallVector<Operation *, 4> &storesForMemref = | ||
memrefToStoresPair.second; | ||
Value newMemRef = createPrivateMemRef( | ||
dstAffineForOp, storesForMemref[0], bestDstLoopDepth, | ||
fastMemorySpace, localBufSizeThreshold); | ||
Operation *srcStoreOpInst = storesForMemref[0]; | ||
Value newMemRef; | ||
|
||
if (isa<AffineVectorLoadOp, AffineVectorStoreOp>(srcStoreOpInst)) { | ||
newMemRef = createPrivateVectorOpMemRef( | ||
dstAffineForOp, srcStoreOpInst, bestDstLoopDepth, | ||
fastMemorySpace, localBufSizeThreshold); | ||
} else { | ||
newMemRef = createPrivateMemRef(dstAffineForOp, srcStoreOpInst, | ||
bestDstLoopDepth, fastMemorySpace, | ||
localBufSizeThreshold); | ||
} | ||
|
||
// Create new node in dependence graph for 'newMemRef' alloc op. | ||
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); | ||
// Add edge from 'newMemRef' node to dstNode. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop blank lines like these - here, above and below.