@@ -313,6 +313,64 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
313
313
++idx;
314
314
}
315
315
}
316
+
317
+ // / Collects values that are local to a loop: "loop-local values". A loop-local
318
+ // / value is one that is used exclusively inside the loop but allocated outside
319
+ // / of it. This usually corresponds to temporary values that are used inside the
320
+ // / loop body for initialzing other variables for example.
321
+ // /
322
+ // / See `flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90` for an
323
+ // / example of why we need this.
324
+ // /
325
+ // / \param [in] doLoop - the loop within which the function searches for values
326
+ // / used exclusively inside.
327
+ // /
328
+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
329
+ void collectLoopLocalValues (fir::DoLoopOp doLoop,
330
+ llvm::SetVector<mlir::Value> &locals) {
331
+ doLoop.walk ([&](mlir::Operation *op) {
332
+ for (mlir::Value operand : op->getOperands ()) {
333
+ if (locals.contains (operand))
334
+ continue ;
335
+
336
+ bool isLocal = true ;
337
+
338
+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
339
+ continue ;
340
+
341
+ // Values defined inside the loop are not interesting since they do not
342
+ // need to be localized.
343
+ if (doLoop->isAncestor (operand.getDefiningOp ()))
344
+ continue ;
345
+
346
+ for (auto *user : operand.getUsers ()) {
347
+ if (!doLoop->isAncestor (user)) {
348
+ isLocal = false ;
349
+ break ;
350
+ }
351
+ }
352
+
353
+ if (isLocal)
354
+ locals.insert (operand);
355
+ }
356
+ });
357
+ }
358
+
359
+ // / For a "loop-local" value \p local within a loop's scope, localizes that
360
+ // / value within the scope of the parallel region the loop maps to. Towards that
361
+ // / end, this function moves the allocation of \p local within \p allocRegion.
362
+ // /
363
+ // / \param local - the value used exclusively within a loop's scope (see
364
+ // / collectLoopLocalValues).
365
+ // /
366
+ // / \param allocRegion - the parallel region where \p local's allocation will be
367
+ // / privatized.
368
+ // /
369
+ // / \param rewriter - builder used for updating \p allocRegion.
370
+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
371
+ mlir::ConversionPatternRewriter &rewriter) {
372
+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
373
+ }
316
374
} // namespace looputils
317
375
318
376
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -339,13 +397,21 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
339
397
" Some `do concurent` loops are not perfectly-nested. "
340
398
" These will be serialized." );
341
399
400
+ llvm::SetVector<mlir::Value> locals;
401
+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
342
402
looputils::sinkLoopIVArgs (rewriter, loopNest);
403
+
343
404
mlir::IRMapping mapper;
344
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
405
+ mlir::omp::ParallelOp parallelOp =
406
+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
345
407
mlir::omp::LoopNestOperands loopNestClauseOps;
346
408
genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
347
409
loopNestClauseOps);
348
410
411
+ for (mlir::Value local : locals)
412
+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
413
+ rewriter);
414
+
349
415
mlir::omp::LoopNestOp ompLoopNest =
350
416
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
351
417
/* isComposite=*/ mapToDevice);
0 commit comments