@@ -329,6 +329,64 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
329
329
++idx;
330
330
}
331
331
}
332
+
333
+ // / Collects values that are local to a loop: "loop-local values". A loop-local
334
+ // / value is one that is used exclusively inside the loop but allocated outside
335
+ // / of it. This usually corresponds to temporary values that are used inside the
336
+ // / loop body for initialzing other variables for example.
337
+ // /
338
+ // / See `flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90` for an
339
+ // / example of why we need this.
340
+ // /
341
+ // / \param [in] doLoop - the loop within which the function searches for values
342
+ // / used exclusively inside.
343
+ // /
344
+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
345
+ void collectLoopLocalValues (fir::DoLoopOp doLoop,
346
+ llvm::SetVector<mlir::Value> &locals) {
347
+ doLoop.walk ([&](mlir::Operation *op) {
348
+ for (mlir::Value operand : op->getOperands ()) {
349
+ if (locals.contains (operand))
350
+ continue ;
351
+
352
+ bool isLocal = true ;
353
+
354
+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
355
+ continue ;
356
+
357
+ // Values defined inside the loop are not interesting since they do not
358
+ // need to be localized.
359
+ if (doLoop->isAncestor (operand.getDefiningOp ()))
360
+ continue ;
361
+
362
+ for (auto *user : operand.getUsers ()) {
363
+ if (!doLoop->isAncestor (user)) {
364
+ isLocal = false ;
365
+ break ;
366
+ }
367
+ }
368
+
369
+ if (isLocal)
370
+ locals.insert (operand);
371
+ }
372
+ });
373
+ }
374
+
375
+ // / For a "loop-local" value \p local within a loop's scope, localizes that
376
+ // / value within the scope of the parallel region the loop maps to. Towards that
377
+ // / end, this function moves the allocation of \p local within \p allocRegion.
378
+ // /
379
+ // / \param local - the value used exclusively within a loop's scope (see
380
+ // / collectLoopLocalValues).
381
+ // /
382
+ // / \param allocRegion - the parallel region where \p local's allocation will be
383
+ // / privatized.
384
+ // /
385
+ // / \param rewriter - builder used for updating \p allocRegion.
386
+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
387
+ mlir::ConversionPatternRewriter &rewriter) {
388
+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
389
+ }
332
390
} // namespace looputils
333
391
334
392
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -355,13 +413,21 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
355
413
" Some `do concurent` loops are not perfectly-nested. "
356
414
" These will be serialized." );
357
415
416
+ llvm::SetVector<mlir::Value> locals;
417
+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
358
418
looputils::sinkLoopIVArgs (rewriter, loopNest);
419
+
359
420
mlir::IRMapping mapper;
360
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
421
+ mlir::omp::ParallelOp parallelOp =
422
+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
361
423
mlir::omp::LoopNestOperands loopNestClauseOps;
362
424
genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
363
425
loopNestClauseOps);
364
426
427
+ for (mlir::Value local : locals)
428
+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
429
+ rewriter);
430
+
365
431
mlir::omp::LoopNestOp ompLoopNest =
366
432
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
367
433
/* isComposite=*/ mapToDevice);
0 commit comments