@@ -316,39 +316,64 @@ void collectIndirectConstOpChain(mlir::Operation *link,
316316}
317317
318318// / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
319- // / there are no operations in \p outerloop's other than:
319+ // / there are no operations in \p outerLoop's body other than:
320320// /
321- // / 1. the operations needed to assing /update \p outerLoop's induction variable.
321+ // / 1. the operations needed to assign /update \p outerLoop's induction variable.
322322// / 2. \p innerLoop itself.
323323// /
324324// / \return true if \p innerLoop is perfectly nested inside \p outerLoop
325325// / according to the above definition.
326326bool isPerfectlyNested (fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
327- mlir::BackwardSliceOptions backwardSliceOptions;
328- backwardSliceOptions.inclusive = true ;
329- // We will collect the backward slices for innerLoop's LB, UB, and step.
330- // However, we want to limit the scope of these slices to the scope of
331- // outerLoop's region.
332- backwardSliceOptions.filter = [&](mlir::Operation *op) {
333- return !mlir::areValuesDefinedAbove (op->getResults (),
334- outerLoop.getRegion ());
335- };
336-
337327 mlir::ForwardSliceOptions forwardSliceOptions;
338328 forwardSliceOptions.inclusive = true ;
329+ // The following will be used as an example to clarify the internals of this
330+ // function:
331+ // ```
332+ // 1. fir.do_loop %i_idx = %34 to %36 step %c1 unordered {
333+ // 2. %i_idx_2 = fir.convert %i_idx : (index) -> i32
334+ // 3. fir.store %i_idx_2 to %i_iv#1 : !fir.ref<i32>
335+ //
336+ // 4. fir.do_loop %j_idx = %37 to %39 step %c1_3 unordered {
337+ // 5. %j_idx_2 = fir.convert %j_idx : (index) -> i32
338+ // 6. fir.store %j_idx_2 to %j_iv#1 : !fir.ref<i32>
339+ // ... loop nest body, possible uses %i_idx ...
340+ // }
341+ // }
342+ // ```
343+ // In this example, the `j` loop is perfectly nested inside the `i` loop and
344+ // below is how we find that.
345+
339346 // We don't care about the outer-loop's induction variable's uses within the
340347 // inner-loop, so we filter out these uses.
348+ //
349+ // This filter tells `getForwardSlice` (below) to only collect operations
350+ // which produce results defined above (i.e. outside) the inner-loop's body.
351+ //
352+ // Since `outerLoop.getInductionVar()` is a block argument (to the
353+ // outer-loop's body), the filter effectively collects uses of
354+ // `outerLoop.getInductionVar()` inside the outer-loop but outside the
355+ // inner-loop.
341356 forwardSliceOptions.filter = [&](mlir::Operation *op) {
342357 return mlir::areValuesDefinedAbove (op->getResults (), innerLoop.getRegion ());
343358 };
344359
345360 llvm::SetVector<mlir::Operation *> indVarSlice;
361+ // The forward slice of the `i` loop's IV will be the 2 ops in lines 2 & 3
362+ // above. Uses of `%i_idx` inside the `j` loop are not collected because of
363+ // the filter.
346364 mlir::getForwardSlice (outerLoop.getInductionVar (), &indVarSlice,
347365 forwardSliceOptions);
348- llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet (indVarSlice.begin (),
349- indVarSlice.end ());
350-
351- llvm::DenseSet<mlir::Operation *> loopBodySet;
366+ llvm::DenseSet<mlir::Operation *> indVarSet (indVarSlice.begin (),
367+ indVarSlice.end ());
368+
369+ llvm::DenseSet<mlir::Operation *> outerLoopBodySet;
370+ // The following walk collects ops inside `outerLoop` that are **not**:
371+ // * the outer-loop itself,
372+ // * or the inner-loop,
373+ // * or the `fir.result` op (the outer-loop's terminator).
374+ //
375+ // For the above example, this will also populate `outerLoopBodySet` with ops
376+ // in lines 2 & 3 since we skip the `i` loop, the `j` loop, and the terminator.
352377 outerLoop.walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
353378 if (op == outerLoop)
354379 return mlir::WalkResult::advance ();
@@ -359,43 +384,48 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
359384 if (mlir::isa<fir::ResultOp>(op))
360385 return mlir::WalkResult::advance ();
361386
362- loopBodySet .insert (op);
387+ outerLoopBodySet .insert (op);
363388 return mlir::WalkResult::advance ();
364389 });
365390
366- bool result = (loopBodySet == innerLoopSetupOpsSet);
391+ // If `outerLoopBodySet` ends up having the same ops as `indVarSet`, then
392+ // `outerLoop` only contains ops that setup its induction variable +
393+ // `innerLoop` + the `fir.result` terminator. In other words, `innerLoop` is
394+ // perfectly nested inside `outerLoop`.
395+ bool result = (outerLoopBodySet == indVarSet);
367396 mlir::Location loc = outerLoop.getLoc ();
368397 LLVM_DEBUG (DBGS () << " Loop pair starting at location " << loc << " is"
369398 << (result ? " " : " not" ) << " perfectly nested\n " );
370399
371400 return result;
372401}
373402
374- // / Starting with `outerLoop ` collect a perfectly nested loop nest, if any. This
375- // / function collects as much as possible loops in the nest; it case it fails to
376- // / recognize a certain nested loop as part of the nest it just returns the
377- // / parent loops it discovered before.
403+ // / Starting with `currentLoop`, collect a perfectly nested loop nest, if any.
404+ // / This function collects as many loops as possible in the nest; in case it
405+ // / fails to recognize a certain nested loop as part of the nest, it just
406+ // / returns the parent loops it discovered before.
378407mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop,
379408 LoopNestToIndVarMap &loopNest) {
380409 assert (currentLoop.getUnordered ());
381410
382411 while (true ) {
383- loopNest.try_emplace (
384- currentLoop,
385- InductionVariableInfo{
386- findLoopIndVarMemDecl (currentLoop),
387- std::move (looputils::extractIndVarUpdateOps (currentLoop))});
388-
389- auto directlyNestedLoops = currentLoop.getRegion ().getOps <fir::DoLoopOp>();
412+ loopNest.insert (
413+ {currentLoop,
414+ InductionVariableInfo{
415+ findLoopIndVarMemDecl (currentLoop),
416+ std::move (looputils::extractIndVarUpdateOps (currentLoop))}});
390417 llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
391418
392- for (auto nestedLoop : directlyNestedLoops )
419+ for (auto nestedLoop : currentLoop. getRegion (). getOps <fir::DoLoopOp>() )
393420 if (nestedLoop.getUnordered ())
394421 unorderedLoops.push_back (nestedLoop);
395422
396423 if (unorderedLoops.empty ())
397424 break ;
398425
426+ // Having more than one unordered loop means that we are not dealing with a
427+ // perfect loop nest (i.e. a multi-range `do concurrent` loop), which is the
428+ // case we are after here.
399429 if (unorderedLoops.size () > 1 )
400430 return mlir::failure ();
401431
0 commit comments