#include "flang/Optimizer/OpenMP/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -24,7 +25,67 @@ namespace flangomp {
namespace {
namespace looputils {
-using LoopNest = llvm::SetVector<fir::DoLoopOp>;
+/// Stores info needed about the induction/iteration variable for each `do
+/// concurrent` in a loop nest.
+struct InductionVariableInfo {
+  /// The operation allocating memory for the iteration variable.
+  mlir::Operation *iterVarMemDef;
+};
+
+using LoopNestToIndVarMap =
+    llvm::MapVector<fir::DoLoopOp, InductionVariableInfo>;
+
+/// For the \p doLoop parameter, find the operation that declares its iteration
+/// variable or allocates memory for it.
+///
+/// For example, given the following loop:
+/// ```
+///   ...
+///   %i:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : ...
+///   ...
+///   fir.do_loop %ind_var = %lb to %ub step %s unordered {
+///     %ind_var_conv = fir.convert %ind_var : (index) -> i32
+///     fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
+///     ...
+///   }
+/// ```
+///
+/// This function returns the `hlfir.declare` op for `%i`.
+///
+/// Note: The current implementation depends on how flang emits loop bodies;
+/// this is sufficient for the current simple test/use cases. If it proves to
+/// be insufficient, this function should be made more generic.
+mlir::Operation *findLoopIterationVarMemDecl(fir::DoLoopOp doLoop) {
+  mlir::Value result = nullptr;
+
+  // Checks if a StoreOp is updating the memref of the loop's iteration
+  // variable.
+  auto isStoringIV = [&](fir::StoreOp storeOp) {
+    // Direct store into the IV memref.
+    if (storeOp.getValue() == doLoop.getInductionVar())
+      return true;
+
+    // Indirect store into the IV memref.
+    if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(
+            storeOp.getValue().getDefiningOp())) {
+      if (convertOp.getOperand() == doLoop.getInductionVar())
+        return true;
+    }
+
+    return false;
+  };
+
+  for (mlir::Operation &op : doLoop) {
+    if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(op))
+      if (isStoringIV(storeOp)) {
+        result = storeOp.getMemref();
+        break;
+      }
+  }
+
+  assert(result != nullptr && result.getDefiningOp() != nullptr);
+  return result.getDefiningOp();
+}

/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
/// there are no operations in \p outerLoop's body other than:
@@ -116,11 +177,14 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
/// fails to recognize a certain nested loop as part of the nest, it just
/// returns the parent loops it discovered before.
mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop,
-                                    LoopNest &loopNest) {
+                                    LoopNestToIndVarMap &loopNest) {
  assert(currentLoop.getUnordered());

  while (true) {
-    loopNest.insert(currentLoop);
+    loopNest.insert(
+        {currentLoop,
+         InductionVariableInfo{findLoopIterationVarMemDecl(currentLoop)}});
+
    llvm::SmallVector<fir::DoLoopOp> unorderedLoops;

    for (auto nestedLoop : currentLoop.getRegion().getOps<fir::DoLoopOp>())
@@ -152,26 +216,140 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
public:
  using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;

-  DoConcurrentConversion(mlir::MLIRContext *context, bool mapToDevice)
-      : OpConversionPattern(context), mapToDevice(mapToDevice) {}
+  DoConcurrentConversion(mlir::MLIRContext *context, bool mapToDevice,
+                         llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip)
+      : OpConversionPattern(context), mapToDevice(mapToDevice),
+        concurrentLoopsToSkip(concurrentLoopsToSkip) {}

  mlir::LogicalResult
  matchAndRewrite(fir::DoLoopOp doLoop, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
-    looputils::LoopNest loopNest;
+    if (mapToDevice)
+      return doLoop.emitError(
+          "not yet implemented: Mapping `do concurrent` loops to device");
+
+    looputils::LoopNestToIndVarMap loopNest;
    bool hasRemainingNestedLoops =
        failed(looputils::collectLoopNest(doLoop, loopNest));
    if (hasRemainingNestedLoops)
      mlir::emitWarning(doLoop.getLoc(),
                        "Some `do concurrent` loops are not perfectly-nested. "
                        "These will be serialized.");

-    // TODO This will be filled in with the next PRs that upstream the rest of
-    // the ROCm implementation.
+    mlir::IRMapping mapper;
+    genParallelOp(doLoop.getLoc(), rewriter, loopNest, mapper);
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(doLoop.getLoc(), rewriter, loopNest, mapper,
+                         loopNestClauseOps);
+
+    mlir::omp::LoopNestOp ompLoopNest =
+        genWsLoopOp(rewriter, loopNest.back().first, mapper, loopNestClauseOps,
+                    /*isComposite=*/mapToDevice);
+
+    rewriter.eraseOp(doLoop);
+
+    // Mark `unordered` loops that are not perfectly nested to be skipped from
+    // the legality check of the `ConversionTarget` since we are not interested
+    // in mapping them to OpenMP.
+    ompLoopNest->walk([&](fir::DoLoopOp doLoop) {
+      if (doLoop.getUnordered()) {
+        concurrentLoopsToSkip.insert(doLoop);
+      }
+    });
+
    return mlir::success();
  }
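
To make the shape of the rewrite concrete: for a host-mapped, single-level `do concurrent` loop, the helpers called above are expected to produce IR roughly like the following sketch (hand-written for illustration, not verified compiler output; SSA names and types are placeholders):

```mlir
omp.parallel {
  // Clones of the iteration-variable alloca and its hlfir.declare,
  // emitted by genLoopNestIndVarAllocs/genInductionVariableAlloc.
  %priv_i = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
  %priv_i_decl:2 = hlfir.declare %priv_i {uniq_name = "_QFEi"}
      : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
  omp.wsloop {
    omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
      // ... cloned body of the original `fir.do_loop ... unordered` ...
      omp.yield
    }
  }
  omp.terminator
}
```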

+private:
+  mlir::omp::ParallelOp genParallelOp(mlir::Location loc,
+                                      mlir::ConversionPatternRewriter &rewriter,
+                                      looputils::LoopNestToIndVarMap &loopNest,
+                                      mlir::IRMapping &mapper) const {
+    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+    rewriter.createBlock(&parallelOp.getRegion());
+    rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+
+    genLoopNestIndVarAllocs(rewriter, loopNest, mapper);
+    return parallelOp;
+  }
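
Note that after `genParallelOp` returns, the rewriter's insertion point sits just before the region's `omp.terminator`, so the induction-variable clones and the worksharing loop generated next all land inside the parallel region. Schematically (illustrative only):

```mlir
omp.parallel {
  // insertion point after genParallelOp: subsequent ops (IV clones,
  // omp.wsloop/omp.loop_nest) are created here
  omp.terminator
}
```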
+
+  void genLoopNestIndVarAllocs(mlir::ConversionPatternRewriter &rewriter,
+                               looputils::LoopNestToIndVarMap &loopNest,
+                               mlir::IRMapping &mapper) const {
+
+    for (auto &[_, indVarInfo] : loopNest)
+      genInductionVariableAlloc(rewriter, indVarInfo.iterVarMemDef, mapper);
+  }
+
+  mlir::Operation *
+  genInductionVariableAlloc(mlir::ConversionPatternRewriter &rewriter,
+                            mlir::Operation *indVarMemDef,
+                            mlir::IRMapping &mapper) const {
+    assert(
+        indVarMemDef != nullptr &&
+        "Induction variable memdef is expected to have a defining operation.");
+
+    llvm::SmallSetVector<mlir::Operation *, 2> indVarDeclareAndAlloc;
+    for (auto operand : indVarMemDef->getOperands())
+      indVarDeclareAndAlloc.insert(operand.getDefiningOp());
+    indVarDeclareAndAlloc.insert(indVarMemDef);
+
+    mlir::Operation *result;
+    for (mlir::Operation *opToClone : indVarDeclareAndAlloc)
+      result = rewriter.clone(*opToClone, mapper);
+
+    return result;
+  }
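
As an illustration of what `genInductionVariableAlloc` clones: for the `hlfir.declare` found by `findLoopIterationVarMemDecl`, the defining ops of its operands (typically the backing `fir.alloca`) are cloned first, followed by the declare itself, and `mapper` records the old-to-new value mapping used later when cloning the loop body. A rough sketch, with made-up SSA names:

```mlir
// Original host IR, outside the parallel region:
%0   = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
%1:2 = hlfir.declare %0 {uniq_name = "_QFEi"}
         : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)

// Clones emitted inside omp.parallel; the body cloned by genWsLoopOp
// refers to %2/%3 instead of %0/%1 via the IRMapping:
%2   = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
%3:2 = hlfir.declare %2 {uniq_name = "_QFEi"}
         : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
```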
+
+  void genLoopNestClauseOps(
+      mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+      looputils::LoopNestToIndVarMap &loopNest, mlir::IRMapping &mapper,
+      mlir::omp::LoopNestOperands &loopNestClauseOps) const {
+    assert(loopNestClauseOps.loopLowerBounds.empty() &&
+           "Loop nest bounds were already emitted!");
+
+    auto populateBounds = [](mlir::Value var,
+                             llvm::SmallVectorImpl<mlir::Value> &bounds) {
+      bounds.push_back(var.getDefiningOp()->getResult(0));
+    };
+
+    for (auto &[doLoop, _] : loopNest) {
+      populateBounds(doLoop.getLowerBound(), loopNestClauseOps.loopLowerBounds);
+      populateBounds(doLoop.getUpperBound(), loopNestClauseOps.loopUpperBounds);
+      populateBounds(doLoop.getStep(), loopNestClauseOps.loopSteps);
+    }
+
+    loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+  }
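
For a concrete (illustrative) picture of the clause collection, consider a two-level nest; bounds are gathered outermost loop first, directly from the `fir.do_loop` operands, and `loopInclusive` is set because `fir.do_loop` upper bounds are inclusive:

```mlir
%c1 = arith.constant 1 : index
fir.do_loop %i = %c1 to %n step %c1 unordered {
  fir.do_loop %j = %c1 to %m step %c1 unordered {
    // ...
  }
}
// Collected clause operands (outermost first):
//   loopLowerBounds = [%c1, %c1]
//   loopUpperBounds = [%n,  %m]
//   loopSteps       = [%c1, %c1]
//   loopInclusive   = unit attribute
```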
+
+  mlir::omp::LoopNestOp
+  genWsLoopOp(mlir::ConversionPatternRewriter &rewriter, fir::DoLoopOp doLoop,
+              mlir::IRMapping &mapper,
+              const mlir::omp::LoopNestOperands &clauseOps,
+              bool isComposite) const {
+
+    auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+    wsloopOp.setComposite(isComposite);
+    rewriter.createBlock(&wsloopOp.getRegion());
+
+    auto loopNestOp =
+        rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+    // Clone the loop's body inside the loop nest construct using the
+    // mapped values.
+    rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                               loopNestOp.getRegion().begin(), mapper);
+
+    mlir::Operation *terminator = loopNestOp.getRegion().back().getTerminator();
+    rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
+    rewriter.create<mlir::omp::YieldOp>(terminator->getLoc());
+    rewriter.eraseOp(terminator);
+
+    return loopNestOp;
+  }
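
The terminator handling at the end is the subtle part: `cloneRegionBefore` copies the original loop body including its `fir.result` terminator, which `omp.loop_nest` does not accept, so it is replaced by an `omp.yield`. The resulting wrapper nest looks roughly like this (illustrative):

```mlir
omp.wsloop {
  omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
    // ... cloned loop body, values remapped through `mapper` ...
    omp.yield  // replaces the cloned fir.result terminator
  }
}
```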
+
  bool mapToDevice;
+  llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip;
};

class DoConcurrentConversionPass
@@ -200,24 +378,24 @@ class DoConcurrentConversionPass
      return;
    }

+    llvm::DenseSet<fir::DoLoopOp> concurrentLoopsToSkip;
    mlir::RewritePatternSet patterns(context);
    patterns.insert<DoConcurrentConversion>(
-        context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device);
+        context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
+        concurrentLoopsToSkip);
    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<fir::DoLoopOp>([&](fir::DoLoopOp op) {
      // The goal is to handle constructs that eventually get lowered to
      // `fir.do_loop` with the `unordered` attribute (e.g. array expressions).
      // Currently, this is only enabled for the `do concurrent` construct since
      // the pass runs early in the pipeline.
-      return !op.getUnordered();
+      return !op.getUnordered() || concurrentLoopsToSkip.contains(op);
    });
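
As an example of the interplay with `concurrentLoopsToSkip`: in an imperfect nest only the outermost loop is mapped, while the inner `unordered` loop is cloned unchanged into the `omp.loop_nest` body, registered in the skip set by the conversion pattern, and therefore reported as legal here (it simply runs serialized inside each thread's chunk). A rough sketch:

```mlir
omp.parallel {
  omp.wsloop {
    omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st) {
      // an op between the two loops makes the nest imperfect
      fir.do_loop %j = %lb2 to %ub2 step %st2 unordered {
        // left as fir.do_loop; legal because it is in concurrentLoopsToSkip
      }
      omp.yield
    }
  }
  omp.terminator
}
```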
    target.markUnknownOpDynamicallyLegal(
        [](mlir::Operation *) { return true; });

    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
                                               std::move(patterns)))) {
-      mlir::emitError(mlir::UnknownLoc::get(context),
-                      "error in converting do-concurrent op");
      signalPassFailure();
    }
  }