1616#include " mlir/Dialect/MemRef/IR/MemRef.h"
1717#include " mlir/Dialect/SCF/IR/SCF.h"
1818#include " mlir/Dialect/SCF/Transforms/Transforms.h"
19+ #include " mlir/Dialect/SCF/Utils/Utils.h"
1920#include " mlir/IR/Builders.h"
2021#include " mlir/IR/IRMapping.h"
2122#include " mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
3031using namespace mlir ;
3132using namespace mlir ::scf;
3233
33- // / Verify there are no nested ParallelOps.
34- static bool hasNestedParallelOp (ParallelOp ploop) {
35- auto walkResult =
36- ploop.getBody ()->walk ([](ParallelOp) { return WalkResult::interrupt (); });
37- return walkResult.wasInterrupted ();
38- }
39-
40- // / Verify equal iteration spaces.
41- static bool equalIterationSpaces (ParallelOp firstPloop,
42- ParallelOp secondPloop) {
43- if (firstPloop.getNumLoops () != secondPloop.getNumLoops ())
44- return false ;
45-
46- auto matchOperands = [&](const OperandRange &lhs,
47- const OperandRange &rhs) -> bool {
48- // TODO: Extend this to support aliases and equal constants.
49- return std::equal (lhs.begin (), lhs.end (), rhs.begin ());
50- };
51- return matchOperands (firstPloop.getLowerBound (),
52- secondPloop.getLowerBound ()) &&
53- matchOperands (firstPloop.getUpperBound (),
54- secondPloop.getUpperBound ()) &&
55- matchOperands (firstPloop.getStep (), secondPloop.getStep ());
56- }
57-
58- // / Checks if the parallel loops have mixed access to the same buffers. Returns
59- // / `true` if the first parallel loop writes to the same indices that the second
60- // / loop reads.
61- static bool haveNoReadsAfterWriteExceptSameIndex (
62- ParallelOp firstPloop, ParallelOp secondPloop,
63- const IRMapping &firstToSecondPloopIndices,
64- llvm::function_ref<bool (Value, Value)> mayAlias) {
65- DenseMap<Value, SmallVector<ValueRange, 1 >> bufferStores;
66- SmallVector<Value> bufferStoresVec;
67- firstPloop.getBody ()->walk ([&](memref::StoreOp store) {
68- bufferStores[store.getMemRef ()].push_back (store.getIndices ());
69- bufferStoresVec.emplace_back (store.getMemRef ());
70- });
71- auto walkResult = secondPloop.getBody ()->walk ([&](memref::LoadOp load) {
72- Value loadMem = load.getMemRef ();
73- // Stop if the memref is defined in secondPloop body. Careful alias analysis
74- // is needed.
75- auto *memrefDef = loadMem.getDefiningOp ();
76- if (memrefDef && memrefDef->getBlock () == load->getBlock ())
77- return WalkResult::interrupt ();
78-
79- for (Value store : bufferStoresVec)
80- if (store != loadMem && mayAlias (store, loadMem))
81- return WalkResult::interrupt ();
82-
83- auto write = bufferStores.find (loadMem);
84- if (write == bufferStores.end ())
85- return WalkResult::advance ();
86-
87- // Check that at last one store was retrieved
88- if (!write->second .size ())
89- return WalkResult::interrupt ();
90-
91- auto storeIndices = write->second .front ();
92-
93- // Multiple writes to the same memref are allowed only on the same indices
94- for (const auto &othStoreIndices : write->second ) {
95- if (othStoreIndices != storeIndices)
96- return WalkResult::interrupt ();
97- }
98-
99- // Check that the load indices of secondPloop coincide with store indices of
100- // firstPloop for the same memrefs.
101- auto loadIndices = load.getIndices ();
102- if (storeIndices.size () != loadIndices.size ())
103- return WalkResult::interrupt ();
104- for (int i = 0 , e = storeIndices.size (); i < e; ++i) {
105- if (firstToSecondPloopIndices.lookupOrDefault (storeIndices[i]) !=
106- loadIndices[i]) {
107- auto *storeIndexDefOp = storeIndices[i].getDefiningOp ();
108- auto *loadIndexDefOp = loadIndices[i].getDefiningOp ();
109- if (storeIndexDefOp && loadIndexDefOp) {
110- if (!isMemoryEffectFree (storeIndexDefOp))
111- return WalkResult::interrupt ();
112- if (!isMemoryEffectFree (loadIndexDefOp))
113- return WalkResult::interrupt ();
114- if (!OperationEquivalence::isEquivalentTo (
115- storeIndexDefOp, loadIndexDefOp,
116- [&](Value storeIndex, Value loadIndex) {
117- if (firstToSecondPloopIndices.lookupOrDefault (storeIndex) !=
118- firstToSecondPloopIndices.lookupOrDefault (loadIndex))
119- return failure ();
120- else
121- return success ();
122- },
123- /* markEquivalent=*/ nullptr ,
124- OperationEquivalence::Flags::IgnoreLocations)) {
125- return WalkResult::interrupt ();
126- }
127- } else
128- return WalkResult::interrupt ();
129- }
130- }
131- return WalkResult::advance ();
132- });
133- return !walkResult.wasInterrupted ();
134- }
135-
136- // / Analyzes dependencies in the most primitive way by checking simple read and
137- // / write patterns.
138- static LogicalResult
139- verifyDependencies (ParallelOp firstPloop, ParallelOp secondPloop,
140- const IRMapping &firstToSecondPloopIndices,
141- llvm::function_ref<bool (Value, Value)> mayAlias) {
142- if (!haveNoReadsAfterWriteExceptSameIndex (
143- firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144- return failure ();
145-
146- IRMapping secondToFirstPloopIndices;
147- secondToFirstPloopIndices.map (secondPloop.getBody ()->getArguments (),
148- firstPloop.getBody ()->getArguments ());
149- return success (haveNoReadsAfterWriteExceptSameIndex (
150- secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151- }
152-
153- static bool isFusionLegal (ParallelOp firstPloop, ParallelOp secondPloop,
154- const IRMapping &firstToSecondPloopIndices,
155- llvm::function_ref<bool (Value, Value)> mayAlias) {
156- return !hasNestedParallelOp (firstPloop) &&
157- !hasNestedParallelOp (secondPloop) &&
158- equalIterationSpaces (firstPloop, secondPloop) &&
159- succeeded (verifyDependencies (firstPloop, secondPloop,
160- firstToSecondPloopIndices, mayAlias));
161- }
162-
163- // / Prepends operations of firstPloop's body into secondPloop's body.
164- // / Updates secondPloop with new loop.
165- static void fuseIfLegal (ParallelOp firstPloop, ParallelOp &secondPloop,
166- OpBuilder builder,
167- llvm::function_ref<bool (Value, Value)> mayAlias) {
168- Block *block1 = firstPloop.getBody ();
169- Block *block2 = secondPloop.getBody ();
170- IRMapping firstToSecondPloopIndices;
171- firstToSecondPloopIndices.map (block1->getArguments (), block2->getArguments ());
172-
173- if (!isFusionLegal (firstPloop, secondPloop, firstToSecondPloopIndices,
174- mayAlias))
175- return ;
176-
177- DominanceInfo dom;
178- // We are fusing first loop into second, make sure there are no users of the
179- // first loop results between loops.
180- for (Operation *user : firstPloop->getUsers ())
181- if (!dom.properlyDominates (secondPloop, user, /* enclosingOpOk*/ false ))
182- return ;
183-
184- ValueRange inits1 = firstPloop.getInitVals ();
185- ValueRange inits2 = secondPloop.getInitVals ();
186-
187- SmallVector<Value> newInitVars (inits1.begin (), inits1.end ());
188- newInitVars.append (inits2.begin (), inits2.end ());
189-
190- IRRewriter b (builder);
191- b.setInsertionPoint (secondPloop);
192- auto newSecondPloop = b.create <ParallelOp>(
193- secondPloop.getLoc (), secondPloop.getLowerBound (),
194- secondPloop.getUpperBound (), secondPloop.getStep (), newInitVars);
195-
196- Block *newBlock = newSecondPloop.getBody ();
197- auto term1 = cast<ReduceOp>(block1->getTerminator ());
198- auto term2 = cast<ReduceOp>(block2->getTerminator ());
199-
200- b.inlineBlockBefore (block2, newBlock, newBlock->begin (),
201- newBlock->getArguments ());
202- b.inlineBlockBefore (block1, newBlock, newBlock->begin (),
203- newBlock->getArguments ());
204-
205- ValueRange results = newSecondPloop.getResults ();
206- if (!results.empty ()) {
207- b.setInsertionPointToEnd (newBlock);
208-
209- ValueRange reduceArgs1 = term1.getOperands ();
210- ValueRange reduceArgs2 = term2.getOperands ();
211- SmallVector<Value> newReduceArgs (reduceArgs1.begin (), reduceArgs1.end ());
212- newReduceArgs.append (reduceArgs2.begin (), reduceArgs2.end ());
213-
214- auto newReduceOp = b.create <scf::ReduceOp>(term2.getLoc (), newReduceArgs);
215-
216- for (auto &&[i, reg] : llvm::enumerate (llvm::concat<Region>(
217- term1.getReductions (), term2.getReductions ()))) {
218- Block &oldRedBlock = reg.front ();
219- Block &newRedBlock = newReduceOp.getReductions ()[i].front ();
220- b.inlineBlockBefore (&oldRedBlock, &newRedBlock, newRedBlock.begin (),
221- newRedBlock.getArguments ());
222- }
223-
224- firstPloop.replaceAllUsesWith (results.take_front (inits1.size ()));
225- secondPloop.replaceAllUsesWith (results.take_back (inits2.size ()));
226- }
227- term1->erase ();
228- term2->erase ();
229- firstPloop.erase ();
230- secondPloop.erase ();
231- secondPloop = newSecondPloop;
232- }
233-
23434void mlir::scf::naivelyFuseParallelOps (
23535 Region ®ion, llvm::function_ref<bool (Value, Value)> mayAlias) {
23636 OpBuilder b (region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
25959 }
26060 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
26161 for (int i = 0 , e = ploops.size (); i + 1 < e; ++i)
262- fuseIfLegal (ploops[i], ploops[i + 1 ], b, mayAlias);
62+ mlir:: fuseIfLegal (ploops[i], ploops[i + 1 ], b, mayAlias);
26363 }
26464 }
26565}
0 commit comments