@@ -313,6 +313,64 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
313313 ++idx;
314314 }
315315}
316+
317+ // / Collects values that are local to a loop: "loop-local values". A loop-local
318+ // / value is one that is used exclusively inside the loop but allocated outside
319+ // / of it. This usually corresponds to temporary values that are used inside the
320+ // / loop body for initialzing other variables for example.
321+ // /
322+ // / See `flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90` for an
323+ // / example of why we need this.
324+ // /
325+ // / \param [in] doLoop - the loop within which the function searches for values
326+ // / used exclusively inside.
327+ // /
328+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
329+ void collectLoopLocalValues (fir::DoLoopOp doLoop,
330+ llvm::SetVector<mlir::Value> &locals) {
331+ doLoop.walk ([&](mlir::Operation *op) {
332+ for (mlir::Value operand : op->getOperands ()) {
333+ if (locals.contains (operand))
334+ continue ;
335+
336+ bool isLocal = true ;
337+
338+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
339+ continue ;
340+
341+ // Values defined inside the loop are not interesting since they do not
342+ // need to be localized.
343+ if (doLoop->isAncestor (operand.getDefiningOp ()))
344+ continue ;
345+
346+ for (auto *user : operand.getUsers ()) {
347+ if (!doLoop->isAncestor (user)) {
348+ isLocal = false ;
349+ break ;
350+ }
351+ }
352+
353+ if (isLocal)
354+ locals.insert (operand);
355+ }
356+ });
357+ }
358+
359+ // / For a "loop-local" value \p local within a loop's scope, localizes that
360+ // / value within the scope of the parallel region the loop maps to. Towards that
361+ // / end, this function moves the allocation of \p local within \p allocRegion.
362+ // /
363+ // / \param local - the value used exclusively within a loop's scope (see
364+ // / collectLoopLocalValues).
365+ // /
366+ // / \param allocRegion - the parallel region where \p local's allocation will be
367+ // / privatized.
368+ // /
369+ // / \param rewriter - builder used for updating \p allocRegion.
370+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
371+ mlir::ConversionPatternRewriter &rewriter) {
372+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
373+ }
316374} // namespace looputils
317375
318376class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -339,13 +397,21 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
339397 " Some `do concurent` loops are not perfectly-nested. "
340398 " These will be serialized." );
341399
400+ llvm::SetVector<mlir::Value> locals;
401+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
342402 looputils::sinkLoopIVArgs (rewriter, loopNest);
403+
343404 mlir::IRMapping mapper;
344- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
405+ mlir::omp::ParallelOp parallelOp =
406+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
345407 mlir::omp::LoopNestOperands loopNestClauseOps;
346408 genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
347409 loopNestClauseOps);
348410
411+ for (mlir::Value local : locals)
412+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
413+ rewriter);
414+
349415 mlir::omp::LoopNestOp ompLoopNest =
350416 genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
351417 /* isComposite=*/ mapToDevice);
0 commit comments