@@ -28,7 +28,7 @@ namespace looputils {
28
28
// / Stores info needed about the induction/iteration variable for each `do
29
29
// / concurrent` in a loop nest.
30
30
struct InductionVariableInfo {
31
- // / the operation allocating memory for iteration variable,
31
+ // / The operation allocating memory for iteration variable.
32
32
mlir::Operation *iterVarMemDef;
33
33
};
34
34
@@ -57,13 +57,30 @@ using LoopNestToIndVarMap =
57
57
// / proves to be insufficient, this should be made more generic.
58
58
mlir::Operation *findLoopIterationVarMemDecl (fir::DoLoopOp doLoop) {
59
59
mlir::Value result = nullptr ;
60
- for (mlir::Operation &op : doLoop) {
61
- // The first `fir.store` op we come across should be the op that updates the
62
- // loop's iteration variable.
63
- if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(op)) {
64
- result = storeOp.getMemref ();
65
- break ;
60
+
61
+ // Checks if a StoreOp is updating the memref of the loop's iteration
62
+ // variable.
63
+ auto isStoringIV = [&](fir::StoreOp storeOp) {
64
+ // Direct store into the IV memref.
65
+ if (storeOp.getValue () == doLoop.getInductionVar ())
66
+ return true ;
67
+
68
+ // Indirect store into the IV memref.
69
+ if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(
70
+ storeOp.getValue ().getDefiningOp ())) {
71
+ if (convertOp.getOperand () == doLoop.getInductionVar ())
72
+ return true ;
66
73
}
74
+
75
+ return false ;
76
+ };
77
+
78
+ for (mlir::Operation &op : doLoop) {
79
+ if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(op))
80
+ if (isStoringIV (storeOp)) {
81
+ result = storeOp.getMemref ();
82
+ break ;
83
+ }
67
84
}
68
85
69
86
assert (result != nullptr && result.getDefiningOp () != nullptr );
@@ -291,8 +308,8 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
291
308
assert (loopNestClauseOps.loopLowerBounds .empty () &&
292
309
" Loop nest bounds were already emitted!" );
293
310
294
- auto populateBounds = [& ](mlir::Value var,
295
- llvm::SmallVectorImpl<mlir::Value> &bounds) {
311
+ auto populateBounds = [](mlir::Value var,
312
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
296
313
bounds.push_back (var.getDefiningOp ()->getResult (0 ));
297
314
};
298
315
0 commit comments