1111#include " flang/Optimizer/OpenMP/Utils.h"
1212#include " mlir/Analysis/SliceAnalysis.h"
1313#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
14+ #include " mlir/IR/IRMapping.h"
1415#include " mlir/Transforms/DialectConversion.h"
1516#include " mlir/Transforms/RegionUtils.h"
1617
@@ -24,7 +25,67 @@ namespace flangomp {
2425
2526namespace {
2627namespace looputils {
27- using LoopNest = llvm::SetVector<fir::DoLoopOp>;
28+ // / Stores info needed about the induction/iteration variable for each `do
29+ // / concurrent` in a loop nest.
30+ struct InductionVariableInfo {
31+ // / The operation allocating memory for iteration variable.
32+ mlir::Operation *iterVarMemDef;
33+ };
34+
35+ using LoopNestToIndVarMap =
36+ llvm::MapVector<fir::DoLoopOp, InductionVariableInfo>;
37+
38+ // / For the \p doLoop parameter, find the operation that declares its iteration
39+ // / variable or allocates memory for it.
40+ // /
41+ // / For example, give the following loop:
42+ // / ```
43+ // / ...
44+ // / %i:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : ...
45+ // / ...
46+ // / fir.do_loop %ind_var = %lb to %ub step %s unordered {
47+ // / %ind_var_conv = fir.convert %ind_var : (index) -> i32
48+ // / fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
49+ // / ...
50+ // / }
51+ // / ```
52+ // /
53+ // / This function returns the `hlfir.declare` op for `%i`.
54+ // /
55+ // / Note: The current implementation is dependent on how flang emits loop
56+ // / bodies; which is sufficient for the current simple test/use cases. If this
57+ // / proves to be insufficient, this should be made more generic.
58+ mlir::Operation *findLoopIterationVarMemDecl (fir::DoLoopOp doLoop) {
59+ mlir::Value result = nullptr ;
60+
61+ // Checks if a StoreOp is updating the memref of the loop's iteration
62+ // variable.
63+ auto isStoringIV = [&](fir::StoreOp storeOp) {
64+ // Direct store into the IV memref.
65+ if (storeOp.getValue () == doLoop.getInductionVar ())
66+ return true ;
67+
68+ // Indirect store into the IV memref.
69+ if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(
70+ storeOp.getValue ().getDefiningOp ())) {
71+ if (convertOp.getOperand () == doLoop.getInductionVar ())
72+ return true ;
73+ }
74+
75+ return false ;
76+ };
77+
78+ for (mlir::Operation &op : doLoop) {
79+ if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(op))
80+ if (isStoringIV (storeOp)) {
81+ result = storeOp.getMemref ();
82+ break ;
83+ }
84+ }
85+
86+ assert (result != nullptr && result.getDefiningOp () != nullptr );
87+ return result.getDefiningOp ();
88+ }
2889
2990// / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
3091// / there are no operations in \p outerloop's body other than:
@@ -116,11 +177,14 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
116177// / fails to recognize a certain nested loop as part of the nest it just returns
117178// / the parent loops it discovered before.
118179mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop,
119- LoopNest &loopNest) {
180+ LoopNestToIndVarMap &loopNest) {
120181 assert (currentLoop.getUnordered ());
121182
122183 while (true ) {
123- loopNest.insert (currentLoop);
184+ loopNest.insert (
185+ {currentLoop,
186+ InductionVariableInfo{findLoopIterationVarMemDecl (currentLoop)}});
187+
124188 llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
125189
126190 for (auto nestedLoop : currentLoop.getRegion ().getOps <fir::DoLoopOp>())
@@ -152,26 +216,140 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
152216public:
153217 using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;
154218
155- DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice)
156- : OpConversionPattern(context), mapToDevice(mapToDevice) {}
219+ DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice,
220+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip)
221+ : OpConversionPattern(context), mapToDevice(mapToDevice),
222+ concurrentLoopsToSkip (concurrentLoopsToSkip) {}
157223
158224 mlir::LogicalResult
159225 matchAndRewrite (fir::DoLoopOp doLoop, OpAdaptor adaptor,
160226 mlir::ConversionPatternRewriter &rewriter) const override {
161- looputils::LoopNest loopNest;
227+ if (mapToDevice)
228+ return doLoop.emitError (
229+ " not yet implemented: Mapping `do concurrent` loops to device" );
230+
231+ looputils::LoopNestToIndVarMap loopNest;
162232 bool hasRemainingNestedLoops =
163233 failed (looputils::collectLoopNest (doLoop, loopNest));
164234 if (hasRemainingNestedLoops)
165235 mlir::emitWarning (doLoop.getLoc (),
166236 " Some `do concurent` loops are not perfectly-nested. "
167237 " These will be serialized." );
168238
169- // TODO This will be filled in with the next PRs that upstreams the rest of
170- // the ROCm implementaion.
239+ mlir::IRMapping mapper;
240+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
241+ mlir::omp::LoopNestOperands loopNestClauseOps;
242+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
243+ loopNestClauseOps);
244+
245+ mlir::omp::LoopNestOp ompLoopNest =
246+ genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
247+ /* isComposite=*/ mapToDevice);
248+
249+ rewriter.eraseOp (doLoop);
250+
251+ // Mark `unordered` loops that are not perfectly nested to be skipped from
252+ // the legality check of the `ConversionTarget` since we are not interested
253+ // in mapping them to OpenMP.
254+ ompLoopNest->walk ([&](fir::DoLoopOp doLoop) {
255+ if (doLoop.getUnordered ()) {
256+ concurrentLoopsToSkip.insert (doLoop);
257+ }
258+ });
259+
171260 return mlir::success ();
172261 }
173262
263+ private:
264+ mlir::omp::ParallelOp genParallelOp (mlir::Location loc,
265+ mlir::ConversionPatternRewriter &rewriter,
266+ looputils::LoopNestToIndVarMap &loopNest,
267+ mlir::IRMapping &mapper) const {
268+ auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loc);
269+ rewriter.createBlock (¶llelOp.getRegion ());
270+ rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
271+
272+ genLoopNestIndVarAllocs (rewriter, loopNest, mapper);
273+ return parallelOp;
274+ }
275+
276+ void genLoopNestIndVarAllocs (mlir::ConversionPatternRewriter &rewriter,
277+ looputils::LoopNestToIndVarMap &loopNest,
278+ mlir::IRMapping &mapper) const {
279+
280+ for (auto &[_, indVarInfo] : loopNest)
281+ genInductionVariableAlloc (rewriter, indVarInfo.iterVarMemDef , mapper);
282+ }
283+
284+ mlir::Operation *
285+ genInductionVariableAlloc (mlir::ConversionPatternRewriter &rewriter,
286+ mlir::Operation *indVarMemDef,
287+ mlir::IRMapping &mapper) const {
288+ assert (
289+ indVarMemDef != nullptr &&
290+ " Induction variable memdef is expected to have a defining operation." );
291+
292+ llvm::SmallSetVector<mlir::Operation *, 2 > indVarDeclareAndAlloc;
293+ for (auto operand : indVarMemDef->getOperands ())
294+ indVarDeclareAndAlloc.insert (operand.getDefiningOp ());
295+ indVarDeclareAndAlloc.insert (indVarMemDef);
296+
297+ mlir::Operation *result;
298+ for (mlir::Operation *opToClone : indVarDeclareAndAlloc)
299+ result = rewriter.clone (*opToClone, mapper);
300+
301+ return result;
302+ }
303+
304+ void genLoopNestClauseOps (
305+ mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
306+ looputils::LoopNestToIndVarMap &loopNest, mlir::IRMapping &mapper,
307+ mlir::omp::LoopNestOperands &loopNestClauseOps) const {
308+ assert (loopNestClauseOps.loopLowerBounds .empty () &&
309+ " Loop nest bounds were already emitted!" );
310+
311+ auto populateBounds = [](mlir::Value var,
312+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
313+ bounds.push_back (var.getDefiningOp ()->getResult (0 ));
314+ };
315+
316+ for (auto &[doLoop, _] : loopNest) {
317+ populateBounds (doLoop.getLowerBound (), loopNestClauseOps.loopLowerBounds );
318+ populateBounds (doLoop.getUpperBound (), loopNestClauseOps.loopUpperBounds );
319+ populateBounds (doLoop.getStep (), loopNestClauseOps.loopSteps );
320+ }
321+
322+ loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
323+ }
324+
325+ mlir::omp::LoopNestOp
326+ genWsLoopOp (mlir::ConversionPatternRewriter &rewriter, fir::DoLoopOp doLoop,
327+ mlir::IRMapping &mapper,
328+ const mlir::omp::LoopNestOperands &clauseOps,
329+ bool isComposite) const {
330+
331+ auto wsloopOp = rewriter.create <mlir::omp::WsloopOp>(doLoop.getLoc ());
332+ wsloopOp.setComposite (isComposite);
333+ rewriter.createBlock (&wsloopOp.getRegion ());
334+
335+ auto loopNestOp =
336+ rewriter.create <mlir::omp::LoopNestOp>(doLoop.getLoc (), clauseOps);
337+
338+ // Clone the loop's body inside the loop nest construct using the
339+ // mapped values.
340+ rewriter.cloneRegionBefore (doLoop.getRegion (), loopNestOp.getRegion (),
341+ loopNestOp.getRegion ().begin (), mapper);
342+
343+ mlir::Operation *terminator = loopNestOp.getRegion ().back ().getTerminator ();
344+ rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
345+ rewriter.create <mlir::omp::YieldOp>(terminator->getLoc ());
346+ rewriter.eraseOp (terminator);
347+
348+ return loopNestOp;
349+ }
350+
174351 bool mapToDevice;
352+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip;
175353};
176354
177355class DoConcurrentConversionPass
@@ -200,24 +378,24 @@ class DoConcurrentConversionPass
200378 return ;
201379 }
202380
381+ llvm::DenseSet<fir::DoLoopOp> concurrentLoopsToSkip;
203382 mlir::RewritePatternSet patterns (context);
204383 patterns.insert <DoConcurrentConversion>(
205- context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device);
384+ context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
385+ concurrentLoopsToSkip);
206386 mlir::ConversionTarget target (*context);
207387 target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
208388 // The goal is to handle constructs that eventually get lowered to
209389 // `fir.do_loop` with the `unordered` attribute (e.g. array expressions).
210390 // Currently, this is only enabled for the `do concurrent` construct since
211391 // the pass runs early in the pipeline.
212- return !op.getUnordered ();
392+ return !op.getUnordered () || concurrentLoopsToSkip. contains (op) ;
213393 });
214394 target.markUnknownOpDynamicallyLegal (
215395 [](mlir::Operation *) { return true ; });
216396
217397 if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
218398 std::move (patterns)))) {
219- mlir::emitError (mlir::UnknownLoc::get (context),
220- " error in converting do-concurrent op" );
221399 signalPassFailure ();
222400 }
223401 }
0 commit comments