Skip to content

Commit 2336b0d

Browse files
committed
[MLIR][Affine] Loop fusion in a block containing Linalg op
Handle region-holding operators implementing Linalg interface in MemRefDependenceGraph.
1 parent 6ce41db commit 2336b0d

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2020
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
22+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2223
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2324
#include "mlir/IR/IntegerSet.h"
2425
#include "mlir/Interfaces/CallInterfaces.h"
@@ -252,6 +253,9 @@ bool MemRefDependenceGraph::init() {
252253
// Create graph nodes.
253254
DenseMap<Operation *, unsigned> forToNodeMap;
254255
for (Operation &op : block) {
256+
bool hasUnsupportedRegion =
257+
op.getNumRegions() != 0 &&
258+
!isa<RegionBranchOpInterface, linalg::LinalgOp>(op);
255259
if (auto forOp = dyn_cast<AffineForOp>(op)) {
256260
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
257261
if (!node)
@@ -277,8 +281,7 @@ bool MemRefDependenceGraph::init() {
277281
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
278282
if (!node)
279283
return false;
280-
} else if (!isMemoryEffectFree(&op) &&
281-
(op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
284+
} else if (!isMemoryEffectFree(&op) && !hasUnsupportedRegion) {
282285
// Create graph node for top-level op unless it is known to be
283286
// memory-effect free. This covers all unknown/unregistered ops,
284287
// non-affine ops with memory effects, and region-holding ops with a
@@ -287,7 +290,7 @@ bool MemRefDependenceGraph::init() {
287290
Node *node = addNodeToMDG(&op, *this, memrefAccesses);
288291
if (!node)
289292
return false;
290-
} else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
293+
} else if (hasUnsupportedRegion) {
291294
// Return false if non-handled/unknown region-holding ops are found. We
292295
// won't know what such ops do or what its regions mean; for e.g., it may
293296
// not be an imperative op.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,36 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
548548

549549
// -----
550550

551+
// Check that presence of a Linalg operator in a block does not prevent
552+
// fusion from happening in this block.
553+
554+
// ALL-LABEL: func @fusion_in_block_containing_linalg
555+
func.func @fusion_in_block_containing_linalg(%arg0: memref<5xi8>, %arg1: memref<5xi8>) {
556+
%c15_i8 = arith.constant 15 : i8
557+
%alloc = memref.alloc() : memref<5xi8>
558+
affine.for %arg3 = 0 to 5 {
559+
affine.store %c15_i8, %alloc[%arg3] : memref<5xi8>
560+
}
561+
affine.for %arg3 = 0 to 5 {
562+
%0 = affine.load %alloc[%arg3] : memref<5xi8>
563+
%1 = affine.load %arg0[%arg3] : memref<5xi8>
564+
%2 = arith.muli %0, %1 : i8
565+
affine.store %2, %alloc[%arg3] : memref<5xi8>
566+
}
567+
// ALL: affine.for
568+
// ALL-NEXT: affine.store
569+
// ALL-NEXT: affine.load
570+
// ALL-NEXT: affine.load
571+
// ALL-NEXT: arith.muli
572+
// ALL-NEXT: affine.store
573+
// ALL-NEXT: }
574+
linalg.elemwise_binary ins(%alloc, %alloc: memref<5xi8>, memref<5xi8>) outs(%arg1: memref<5xi8>)
575+
// ALL-NEXT: linalg.elemwise_binary
576+
return
577+
}
578+
579+
// -----
580+
551581
// From https://github.com/llvm/llvm-project/issues/54541
552582

553583
#map = affine_map<(d0) -> (d0 mod 65536)>

0 commit comments

Comments
 (0)