@@ -336,6 +336,64 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
336
336
++idx;
337
337
}
338
338
}
339
+
340
+ // / Collects values that are local to a loop: "loop-local values". A loop-local
341
+ // / value is one that is used exclusively inside the loop but allocated outside
342
+ // / of it. This usually corresponds to temporary values that are used inside the
343
+ // / loop body for initialzing other variables for example.
344
+ // /
345
+ // / See `flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90` for an
346
+ // / example of why we need this.
347
+ // /
348
+ // / \param [in] doLoop - the loop within which the function searches for values
349
+ // / used exclusively inside.
350
+ // /
351
+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
352
+ void collectLoopLocalValues (fir::DoLoopOp doLoop,
353
+ llvm::SetVector<mlir::Value> &locals) {
354
+ doLoop.walk ([&](mlir::Operation *op) {
355
+ for (mlir::Value operand : op->getOperands ()) {
356
+ if (locals.contains (operand))
357
+ continue ;
358
+
359
+ bool isLocal = true ;
360
+
361
+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
362
+ continue ;
363
+
364
+ // Values defined inside the loop are not interesting since they do not
365
+ // need to be localized.
366
+ if (doLoop->isAncestor (operand.getDefiningOp ()))
367
+ continue ;
368
+
369
+ for (auto *user : operand.getUsers ()) {
370
+ if (!doLoop->isAncestor (user)) {
371
+ isLocal = false ;
372
+ break ;
373
+ }
374
+ }
375
+
376
+ if (isLocal)
377
+ locals.insert (operand);
378
+ }
379
+ });
380
+ }
381
+
382
+ // / For a "loop-local" value \p local within a loop's scope, localizes that
383
+ // / value within the scope of the parallel region the loop maps to. Towards that
384
+ // / end, this function moves the allocation of \p local within \p allocRegion.
385
+ // /
386
+ // / \param local - the value used exclusively within a loop's scope (see
387
+ // / collectLoopLocalValues).
388
+ // /
389
+ // / \param allocRegion - the parallel region where \p local's allocation will be
390
+ // / privatized.
391
+ // /
392
+ // / \param rewriter - builder used for updating \p allocRegion.
393
+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
394
+ mlir::ConversionPatternRewriter &rewriter) {
395
+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
396
+ }
339
397
} // namespace looputils
340
398
341
399
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -358,13 +416,21 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
358
416
" Some `do concurent` loops are not perfectly-nested. "
359
417
" These will be serialzied." );
360
418
419
+ llvm::SetVector<mlir::Value> locals;
420
+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
361
421
looputils::sinkLoopIVArgs (rewriter, loopNest);
422
+
362
423
mlir::IRMapping mapper;
363
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
424
+ mlir::omp::ParallelOp parallelOp =
425
+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
364
426
mlir::omp::LoopNestOperands loopNestClauseOps;
365
427
genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
366
428
loopNestClauseOps);
367
429
430
+ for (mlir::Value local : locals)
431
+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
432
+ rewriter);
433
+
368
434
mlir::omp::LoopNestOp ompLoopNest =
369
435
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
370
436
/* isComposite=*/ mapToDevice);
0 commit comments