Skip to content

Commit de1d351

Browse files
authored
[MLIR][Affine] Fix fusion private memref creation for multiple producer stores (#130365)
Fix private memref creation in affine fusion for the multiple producer store case. This scenario was not supported, but the lack of support was not properly checked for. Fixes: #120227
1 parent f2607df commit de1d351

File tree

3 files changed

+80
-11
lines changed

3 files changed

+80
-11
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,14 +328,34 @@ static std::optional<double> getAdditionalComputeFraction(
328328
// Creates and returns a private (single-user) memref for fused loop rooted at
329329
// 'forOp', with (potentially reduced) memref size based on the memref region
330330
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
331-
// specifies the block in which the slice was/will be inserted.
331+
// specifies the block in which the slice was/will be inserted. The method
332+
// expects that all store ops to the memref have the same access function.
333+
// Returns nullptr if the creation failed.
332334
static Value createPrivateMemRef(AffineForOp forOp,
333335
ArrayRef<Operation *> storeOps,
334336
unsigned dstLoopDepth,
335337
std::optional<unsigned> fastMemorySpace,
336338
Block *sliceInsertionBlock,
337339
uint64_t localBufSizeThreshold) {
338340
assert(!storeOps.empty() && "no source stores supplied");
341+
342+
// Check if all stores have the same access function; we only support this
343+
// case.
344+
// TODO: Use union of memref write regions to compute private memref footprint
345+
// for store ops with different access functions.
346+
if (storeOps.size() > 1 &&
347+
!std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
348+
[](Operation *a, Operation *b) {
349+
MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350+
MemRefAccess bM(cast<AffineWriteOpInterface>(b));
351+
return aM == bM;
352+
})) {
353+
LLVM_DEBUG(llvm::dbgs()
354+
<< "Private memref creation unsupported for multiple producer "
355+
"stores with different access functions.\n");
356+
return nullptr;
357+
}
358+
339359
Operation *srcStoreOp = storeOps[0];
340360

341361
// Create builder to insert alloc op just before 'forOp'.
@@ -432,6 +452,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
432452
assert(succeeded(res) &&
433453
"replaceAllMemrefUsesWith should always succeed here");
434454
(void)res;
455+
LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
456+
<< '\n');
435457
return newMemRef;
436458
}
437459

@@ -1123,13 +1145,12 @@ struct GreedyFusion {
11231145
// loads and stores. Any reference to the original ones becomes
11241146
// invalid after this point.
11251147
for (auto &memrefToStoresPair : privateMemRefToStores) {
1126-
// TODO: Use union of memref write regions to compute
1127-
// private memref footprint.
1128-
SmallVector<Operation *, 4> &storesForMemref =
1129-
memrefToStoresPair.second;
1148+
ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
11301149
Value newMemRef = createPrivateMemRef(
11311150
dstAffineForOp, storesForMemref, bestDstLoopDepth,
11321151
fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1152+
if (!newMemRef)
1153+
continue;
11331154
// Create new node in dependence graph for 'newMemRef' alloc op.
11341155
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
11351156
// Add edge from 'newMemRef' node to dstNode.

mlir/test/Dialect/Affine/loop-fusion-2.mlir

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
3939

4040
// We can fuse source loop nest '%i0' into dst loop nest '%i2', but the
4141
// depth at which we can insert the src loop nest slice into the dst loop
42-
// lest must be decreased because of a loop carried dependence on loop '%i3'.
42+
// nest must be decreased because of a loop carried dependence on loop '%i3'.
4343
// As a result, the source loop nest is inserted at dst loop nest depth 1,
4444
// just above the loop with the carried dependence. In addition, the source
4545
// loop nest iteration bounds on its loop '%i1' are reduced to 1, so the
46-
// memref size can be reduced to 128x1xf32.
46+
// memref size can be reduced to 64x1xf32.
4747

48-
// CHECK: memref.alloc() : memref<64x1xf32>
48+
// In this case, since we have multiple producer stores (for %out) with
49+
// different access functions and we don't yet support private memref
50+
// computation in such cases, a 64x1 private memref isn't created.
51+
52+
// CHECK: memref.alloc() : memref<64x4xf32>
4953
// CHECK: affine.for %{{.*}} = 0 to 4 {
5054
// CHECK-NEXT: affine.for %{{.*}} = 0 to 64 {
51-
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
55+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x4xf32>
5256
// CHECK-NEXT: }
5357
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
5458
// CHECK-NEXT: affine.for %{{.*}} = 0 to 16 {
@@ -62,9 +66,9 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
6266
// CHECK-NEXT: }
6367
// CHECK-NEXT: affine.for %{{.*}} = 0 to 16 {
6468
// CHECK-NEXT: %{{.*}} = "op2"() : () -> f32
65-
// CHECK: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
69+
// CHECK: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
6670
// CHECK-NEXT: arith.addf %{{.*}}, %{{.*}} : f32
67-
// CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
71+
// CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
6872
// CHECK-NEXT: }
6973
// CHECK-NEXT: }
7074
// CHECK-NEXT: }

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,47 @@ func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x13
622622
}
623623
func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
624624
func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64
625+
626+
// An unrolled loop nest. Fusion here should correctly fuse while preserving
627+
// dependences between store-load pairs of the same memref. A private memref
628+
// of size 1x1x1 can't be created.
629+
630+
// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @unrolled
631+
func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
632+
%alloc = memref.alloc() : memref<1x2x4xf32>
633+
affine.for %i = 0 to 1 {
634+
%0 = affine.load %arg0[0, 0] : memref<2x4xf32>
635+
%1 = affine.load %arg0[0, 1] : memref<2x4xf32>
636+
%2 = affine.load %arg0[0, 2] : memref<2x4xf32>
637+
%3 = affine.load %arg0[0, 3] : memref<2x4xf32>
638+
%4 = affine.load %arg0[1, 0] : memref<2x4xf32>
639+
%5 = affine.load %arg0[1, 1] : memref<2x4xf32>
640+
%6 = affine.load %arg0[1, 2] : memref<2x4xf32>
641+
%7 = affine.load %arg0[1, 3] : memref<2x4xf32>
642+
643+
affine.store %0, %alloc[0, 0, 0] : memref<1x2x4xf32>
644+
affine.store %1, %alloc[0, 0, 1] : memref<1x2x4xf32>
645+
affine.store %2, %alloc[0, 0, 2] : memref<1x2x4xf32>
646+
affine.store %3, %alloc[0, 0, 3] : memref<1x2x4xf32>
647+
affine.store %4, %alloc[0, 1, 0] : memref<1x2x4xf32>
648+
affine.store %5, %alloc[0, 1, 1] : memref<1x2x4xf32>
649+
affine.store %6, %alloc[0, 1, 2] : memref<1x2x4xf32>
650+
affine.store %7, %alloc[0, 1, 3] : memref<1x2x4xf32>
651+
}
652+
653+
affine.for %i = 0 to 2 {
654+
affine.for %j = 0 to 4 {
655+
%8 = affine.load %alloc[0, %i, %j] : memref<1x2x4xf32>
656+
%9 = arith.negf %8 : f32
657+
affine.store %9, %arg1[0, %i, %j] : memref<1x2x4xf32>
658+
}
659+
}
660+
// PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 2 {
661+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 {
662+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.load %{{.*}}[0, 0]
663+
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[1, 3]
664+
// PRODUCER-CONSUMER-MAXIMAL: affine.store %{{.*}}[0, 0, 0]
665+
// PRODUCER-CONSUMER-MAXIMAL: affine.store %{{.*}}[0, 1, 3]
666+
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
667+
return
668+
}

0 commit comments

Comments
 (0)