[MLIR][Affine] Fix affine loop fusion with vector ops #115849, #120227 #122799
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine

Author: None (brod4910)

Changes

Patch is 21.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122799.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..5039162ac5bc06 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
@@ -23,14 +25,22 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
#include <iomanip>
+#include <iostream>
#include <optional>
#include <sstream>
@@ -177,13 +187,115 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}
+/// Checks the shapes of the loads and stores of each memref in
+/// the producer/consumer chains. If the load shapes are larger
+/// than the stores then we cannot fuse the loops. The loads
+/// would have a dependency on the values stored.
+static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
+ DenseSet<Value> &producerConsumerMemrefs,
+ MemRefDependenceGraph *mdg) {
+ SmallVector<Operation *> storeOps;
+ SmallVector<Operation *> loadOps;
+
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+
+ for (Value memref : producerConsumerMemrefs) {
+ srcNode->getStoreOpsForMemref(memref, &storeOps);
+ dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+ for (Operation *storeOp : storeOps) {
+ Value storeValue =
+ cast<AffineWriteOpInterface>(storeOp).getValueToStore();
+ auto storeShapedType = dyn_cast<ShapedType>(storeValue.getType());
+
+ if (!storeShapedType)
+ continue;
+
+ for (Operation *loadOp : loadOps) {
+ Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+ auto loadShapedType = dyn_cast<ShapedType>(loadValue.getType());
+
+ if (!loadShapedType)
+ continue;
+
+ for (int i = 0; i < loadShapedType.getRank(); ++i) {
+ auto loadDim = loadShapedType.getDimSize(i);
+ auto storeDim = storeShapedType.getDimSize(i);
+
+ if (loadDim > storeDim)
+ return false;
+ }
+ }
+ }
+
+ storeOps.clear();
+ loadOps.clear();
+ }
+
+ return true;
+}
+
+/// Checks the shapes of the loads and stores of each memref in
+/// the producer/consumer chains. If the load shapes are larger
+/// than the stores then we cannot fuse the loops. The loads
+/// would have a dependency on the values stored.
+static bool checkVectorLoadStoreOps(unsigned srcId, unsigned dstId,
+ DenseSet<Value> &producerConsumerMemrefs,
+ MemRefDependenceGraph *mdg) {
+ SmallVector<Operation *> storeOps;
+ SmallVector<Operation *> loadOps;
+
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+
+ for (Value memref : producerConsumerMemrefs) {
+ srcNode->getStoreOpsForMemref(memref, &storeOps);
+ dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+ for (Operation *storeOp : storeOps) {
+ auto vectorStoreOp = dyn_cast<AffineVectorStoreOp>(storeOp);
+
+ if (!vectorStoreOp)
+ continue;
+
+ auto storeVecType = vectorStoreOp.getVectorType();
+
+ for (Operation *loadOp : loadOps) {
+ auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(loadOp);
+
+ if (!vectorLoadOp)
+ return false;
+
+ auto loadVecType = vectorLoadOp.getVectorType();
+
+ if (loadVecType.getRank() != storeVecType.getRank())
+ return false;
+
+ for (int i = 0; i < loadVecType.getRank(); ++i) {
+ auto loadDim = loadVecType.getDimSize(i);
+ auto storeDim = storeVecType.getDimSize(i);
+
+ if (loadDim > storeDim)
+ return false;
+ }
+ }
+ }
+
+ storeOps.clear();
+ loadOps.clear();
+ }
+
+ return true;
+}
+
/// A memref escapes in the context of the fusion pass if either:
/// 1. it (or its alias) is a block argument, or
/// 2. created by an op not known to guarantee alias freedom,
-/// 3. it (or its alias) are used by ops other than affine dereferencing ops
-/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
-/// terminator ops, etc.); such ops do not deference the memref in an affine
-/// way.
+/// 3. it (or its alias) are used by ops other than affine dereferencing
+/// ops (e.g., by call op, memref load/store ops, alias creating ops,
+/// unknown ops, terminator ops, etc.); such ops do not deference the
+/// memref in an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
Operation *defOp = memref.getDefiningOp();
// Check if 'memref' is a block argument.
@@ -237,6 +349,57 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
node->op = newRootForOp;
}
+static Value createPrivateVectorOpMemRef(
+ AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth,
+ std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) {
+ Operation *forInst = forOp.getOperation();
+
+ // Create builder to insert alloc op just before 'forOp'.
+ OpBuilder b(forInst);
+ // Builder to create constants at the top level.
+ OpBuilder top(forInst->getParentRegion());
+ // Create new memref type based on slice bounds.
+ auto srcAffineOp = cast<AffineWriteOpInterface>(srcStoreOpInst);
+
+ auto oldMemRef = srcAffineOp.getMemRef();
+ auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
+ unsigned rank = oldMemRefType.getRank();
+
+ auto srcOpResult = srcAffineOp.getValueToStore();
+ auto shapedType = dyn_cast<ShapedType>(srcOpResult.getType());
+
+ // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
+ // by 'srcStoreOpInst'.
+ auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
+ assert(eltSize && "memrefs with size elt types expected");
+ uint64_t bufSize = *eltSize * shapedType.getNumElements();
+ unsigned newMemSpace;
+ if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
+ newMemSpace = *fastMemorySpace;
+ } else {
+ newMemSpace = oldMemRefType.getMemorySpaceAsInt();
+ }
+
+ auto newMemRefType = MemRefType::get(
+ shapedType.getShape(), oldMemRefType.getElementType(), {}, newMemSpace);
+
+ // Create new private memref for fused loop 'forOp'. 'newShape' is always
+ // a constant shape.
+ Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
+
+ auto indexRemap = AffineMap::getMultiDimIdentityMap(rank, forOp.getContext());
+
+ // Replace all users of 'oldMemRef' with 'newMemRef'.
+ LogicalResult res =
+ replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/{},
+ /*domOpFilter=*/&*forOp.getBody()->begin());
+ assert(succeeded(res) &&
+ "replaceAllMemrefUsesWith should always succeed here");
+ return newMemRef;
+}
+
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -306,9 +469,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
} else {
newMemSpace = oldMemRefType.getMemorySpaceAsInt();
}
+
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
{}, newMemSpace);
-
// Create new private memref for fused loop 'forOp'. 'newShape' is always
// a constant shape.
// TODO: Create/move alloc ops for private memrefs closer to their
@@ -322,7 +485,6 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
remapExprs.reserve(rank);
for (unsigned i = 0; i < rank; i++) {
auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
-
auto remapExpr =
simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
remapExprs.push_back(remapExpr);
@@ -340,6 +502,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
+
return newMemRef;
}
@@ -516,6 +679,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
// nest slice 'slice' were to be inserted into the dst loop nest at loop
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
+
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
&slice))) {
LLVM_DEBUG(llvm::dbgs()
@@ -798,7 +962,6 @@ struct GreedyFusion {
/// No fusion is performed when producers with a user count greater than
/// `maxSrcUserCount` for any of the memrefs involved.
void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
return;
@@ -858,8 +1021,17 @@ struct GreedyFusion {
}))
continue;
- // Gather memrefs in 'srcNode' that are written and escape out of the
- // block (e.g., memref block arguments, returned memrefs,
+ if (!checkVectorLoadStoreOps(srcId, dstId, producerConsumerMemrefs,
+ mdg)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Can't fuse: vector loop fusion invalid due to either "
+ "src or dst ops are not affine vector ops or load "
+ "dependent on a larger store region\n");
+ continue;
+ }
+
+ // Gather memrefs in 'srcNode' that are written and escape out of
+ // the block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
@@ -907,7 +1079,6 @@ struct GreedyFusion {
dstMemrefOps.push_back(op);
unsigned dstLoopDepthTest =
getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
-
// Check the feasibility of fusing src loop nest into dst loop nest
// at loop depths in range [1, dstLoopDepthTest].
unsigned maxLegalFusionDepth = 0;
@@ -976,9 +1147,6 @@ struct GreedyFusion {
if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
removeSrcNode)) {
// Create a private version of this memref.
- LLVM_DEBUG(llvm::dbgs()
- << "Creating private memref for " << memref << '\n');
- // Create a private version of this memref.
privateMemrefs.insert(memref);
}
}
@@ -1019,9 +1187,19 @@ struct GreedyFusion {
// private memref footprint.
SmallVector<Operation *, 4> &storesForMemref =
memrefToStoresPair.second;
- Value newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
+ Operation *srcStoreOpInst = storesForMemref[0];
+ Value newMemRef;
+
+ if (isa<AffineVectorLoadOp, AffineVectorStoreOp>(srcStoreOpInst)) {
+ newMemRef = createPrivateVectorOpMemRef(
+ dstAffineForOp, srcStoreOpInst, bestDstLoopDepth,
+ fastMemorySpace, localBufSizeThreshold);
+ } else {
+ newMemRef = createPrivateMemRef(dstAffineForOp, srcStoreOpInst,
+ bestDstLoopDepth, fastMemorySpace,
+ localBufSizeThreshold);
+ }
+
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6..bba2d13cc14973 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -285,3 +286,147 @@ module {
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
}
}
+
+// -----
+
+// Basic test for not fusing loops where a vector load depends on
+// the entire result of a previous loop. store shape < load shape
+
+// CHECK-LABEL: func @should_not_fuse_across_memref_store_load_bounds
+func.func @should_not_fuse_across_memref_store_load_bounds() {
+ %a = memref.alloc() : memref<64x512xf32>
+ %b = memref.alloc() : memref<64x512xf32>
+ %c = memref.alloc() : memref<64x512xf32>
+ %d = memref.alloc() : memref<64x4096xf32>
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ }
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+ %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+ affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ }
+
+ return
+}
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8
+// CHECK: %[[lhs:.*]] = affine.vector_load %[[a]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:.*]] = affine.vector_load %[[b]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], %[[c]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: affine.for %[[j_2:.*]] = 0 to 8
+// CHECK: %[[lhs_2:.*]] = affine.vector_load %[[c]][0, 0] : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: %[[rhs_2:.*]] = affine.vector_load %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x512xf32>
+// CHECK: affine.vector_store %[[res_2]], %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: return
+
+// -----
+
+// Basic test for not fusing loops where the dependencies involve
+// an affine vector store and affine loads
+
+// CHECK-LABEL: func @should_not_fuse_vector_store_non_vector_load
+func.func @should_not_fuse_vector_store_non_vector_load() -> memref<64x4096xf32> {
+ %c0 = arith.constant 0 : index
+ %a = memref.alloc() : memref<64x512xf32>
+ %b = memref.alloc() : memref<64x512xf32>
+ %c = memref.alloc() : memref<64x512xf32>
+ %d = memref.alloc() : memref<64x4096xf32>
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ }
+
+ affine.for %k = 0 to 64 {
+ affine.for %m = 0 to 4096 {
+ affine.for %l = 0 to 512 {
+ %lhs = affine.load %c[%k, %l] : memref<64x512xf32>
+ %rhs = affine.load %d[%k, %m] : memref<64x4096xf32>
+ %res = arith.subf %lhs, %rhs : f32
+ affine.store %res, %d[%k, %m] : memref<64x4096xf32>
+ }
+ }
+ }
+
+ return %d : memref<64x4096xf32>
+}
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8 {
+// CHECK: %[[lhs:.*]] = affine.vector_load %[[a]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:.*]] = affine.vector_load %[[b]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], %[[c]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %[[k:.*]] = 0 to 64 {
+// CHECK: affine.for %[[l:.*]] = 0 to 4096 {
+// CHECK: affine.for %[[m:.*]] = 0 to 512 {
+// CHECK: %[[lhs_2:.*]] = affine.load %[[c]][%[[k]], %[[m]]] : memref<64x512xf32>
+// CHECK: %[[rhs_2:.*]] = affine.load %[[d]][%[[k]], %[[l]]] : memref<64x4096xf32>
+// CHECK: %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : f32
+// CHECK: affine.store %[[res_2]], %[[d]][%[[k]], %[[l]]] : memref<64x4096xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[d]] : memref<64x4096xf32>
+
+// -----
+
+// Basic test for fusing loops where a vector load depends on
+// the partial result of a previous loop. store shape > load shape
+
+// CHECK-LABEL: func @should_fuse_across_memref_store_load_bounds
+func.func @should_fuse_across_memref_store_load_bounds() -> memref<64x4096xf32> {
+ %c0 = arith.constant 0 : index
+ %a = memref.alloc() : memref<64x512xf32>
+ %b = memref.alloc() : memref<64x512xf32>
+ %c = memref.alloc() : memref<64x512xf32>
+ %d = memref.alloc() : memref<64x4096xf32>
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ }
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+ %res = arith.subf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+ }
+ return %d : memref<64x4096xf32>
+}
+// CHECK: %[[private:.*]] = memref.alloc() : memref...
[truncated]
However, this solution should cover the bases for now, which is essentially just taking the shape of the vector and using that for the private memref.
Also a fix for this issue: #115989
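A compact sketch of what that heuristic means at the IR level (illustrative only; the allocation shown here is a placeholder, not the patch's actual CHECK output):

// Producer store into the shared buffer, as in the tests above:
//   affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
// Vector-shape heuristic: the private buffer's type mirrors the stored
// vector's shape rather than a region computed from the constraint system.
%private = memref.alloc() : memref<64x64xf32>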
@@ -516,6 +630,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
  // nest slice 'slice' were to be inserted into the dst loop nest at loop
  // depth 'i'.
  MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());

Drop white space change.
if (loadVecType.getRank() != storeVecType.getRank())
  return false;

for (int i = 0; i < loadVecType.getRank(); ++i) {
Use the unsigned i = 0, e = ... form to prevent repeated evaluation.
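For reference, a minimal sketch of the suggested form applied to the loop quoted above (no behavior change intended):

// Evaluate getRank() once and reuse it as the loop bound.
for (unsigned i = 0, e = loadVecType.getRank(); i < e; ++i) {
  int64_t loadDim = loadVecType.getDimSize(i);
  int64_t storeDim = storeVecType.getDimSize(i);
  if (loadDim > storeDim)
    return false;
}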
@@ -306,9 +420,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
  } else {
    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
  }

Drop white space.
@@ -858,8 +972,17 @@ struct GreedyFusion {
    }))
      continue;

    // Gather memrefs in 'srcNode' that are written and escape out of the
    // block (e.g., memref block arguments, returned memrefs,
    if (!checkVectorLoadStoreOps(srcId, dstId, producerConsumerMemrefs,
areIncomptibleVectorLoadStore?

for (Operation *loadOp : loadOps) {
  auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(loadOp);

Drop blank lines like these - here, above and below.
LLVM_DEBUG(llvm::dbgs()
           << "Creating private memref for " << memref << '\n');
Why are these debug lines deleted?
/// Creates a private memref to be used by vector operations.
/// TODO: The difference between this and 'createPrivateMemRef' is that
/// the system for calculating the bounds and constraints doesn't
/// support vector operations. Thus, we use the shape of the vector
/// as our newly created private memref instead of using a constraint
/// system.
static Value createPrivateVectorOpMemRef(
    AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth,
    std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) {
We shouldn't create a new method duplicating some of the logic. Instead, unify the two, i.e., extend createPrivateMemRef for this.
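As a rough sketch of what that unification could look like at the call site (assuming the vector-specific sizing moves inside createPrivateMemRef; this is not part of the patch):

// With a unified helper there is no dispatch on the store op kind here;
// createPrivateMemRef itself would account for affine.vector_store shapes.
Value newMemRef =
    createPrivateMemRef(dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                        fastMemorySpace, localBufSizeThreshold);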
/// Creates a private memref to be used by vector operations.
/// TODO: The difference between this and 'createPrivateMemRef' is that
/// the system for calculating the bounds and constraints doesn't
/// support vector operations. Thus, we use the shape of the vector
But this looks wrong. The vector shape can't be the shape of the private memref. We'll have to use the constraint system. You aren't assuming single element private memref, right?
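To illustrate the concern with a hypothetical fragment (not taken from the patch or its tests): when the fused nest still executes the producer's store several times, the written footprint spans multiple vectors, so a private buffer shaped like a single vector would be too small.

// %buf receives 8 stores of vector<64x64xf32> at column offsets
// 0, 64, ..., 448, so its private copy must cover 64x512 elements,
// not the 64x64 of one stored vector.
affine.for %i = 0 to 8 {
  affine.for %j = 0 to 8 {
    affine.vector_store %v, %buf[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
  }
}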
    affine.vector_store %res, %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
  }
  return %d : memref<64x4096xf32>
}
All test cases are assuming single element private memrefs and so the vector shape would be the private memref shape. This isn't true in the general case. You'll need to multiply the vector shape with the private memref region dimension sizes computed. Test cases will have to be augmented - see another test case where the private memref isn't unit sized.
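A minimal sketch of that sizing rule, under the assumption that the computed region extents are measured in whole vectors (names are placeholders, not code from the patch):

// 'newShape' is the region shape already computed in createPrivateMemRef;
// 'vecType' is the type written by the producer's affine.vector_store
// (ranks assumed to match).
SmallVector<int64_t, 4> privateShape(newShape.begin(), newShape.end());
ArrayRef<int64_t> vecShape = vecType.getShape();
for (unsigned i = 0, e = privateShape.size(); i < e; ++i)
  privateShape[i] *= vecShape[i];
// E.g. a region spanning 8 stores of vector<64x64xf32> along the second
// dimension needs 8 * 64 = 512 columns in the private buffer.
auto newMemRefType = MemRefType::get(privateShape, oldMemRefType.getElementType(),
                                     {}, newMemSpace);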
Fixes involving affine.vector operations and affine loop fusion