
[Flang][OpenMP] Create MLIR optimization pass to push index allocations into loop body and remove them if redundant #67010


Closed

Conversation

@skatrak skatrak commented Sep 21, 2023

This patch adds a Flang-only MLIR optimization pass that aims to remove redundant allocations of loop index variables related to OpenMP loops and to improve LLVM IR code generation. The loop operations for which this is implemented and tested are omp.wsloop and omp.simdloop. The main approaches are to move allocations into the loop body (which later avoids passing these variables as arguments to an outlined function in LLVM IR) and, where possible, to use the block arguments representing loop indices on the loop region directly.

This is done in two stages:

  1. Push allocations (fir.alloca and hlfir.declare) into the loop operation's region. This is only done for allocations that are used to store loop index variables and are only used inside a single loop region. As a result, during MLIR to LLVM IR translation, when the loop operation is lowered by creating a function, the allocation does not need to be passed as an additional argument.
  2. Remove allocations and their related load and store operations, and access the index through the corresponding block argument instead. If the previous step succeeds, this can also be done when all uses of the allocation are fir.load or fir.store, meaning that it is never passed by reference to another function/subprocedure (a small sketch of the overall effect is shown below).

The pass has been implemented to work with and without HLFIR support enabled, and multiple unit tests have been updated due to this pass running by default.
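
To make the two stages concrete, here is a minimal before/after sketch (illustrative only, not taken from the patch; it assumes the omp.wsloop syntax used in the updated tests and made-up function and value names). Before the pass, the loop index lives in a fir.alloca defined outside of the loop region:

  func.func @_QPexample(%lb: i32, %ub: i32, %step: i32) {
    %iv = fir.alloca i32 {adapt.valuebyref, pinned}
    omp.wsloop   for  (%arg0) : i32 = (%lb) to (%ub) inclusive step (%step) {
      fir.store %arg0 to %iv : !fir.ref<i32>
      %0 = fir.load %iv : !fir.ref<i32>
      %1 = arith.addi %0, %0 : i32
      omp.yield
    }
    return
  }

Stage 1 moves the fir.alloca to the top of the loop region, and stage 2 then deletes the alloca together with its store and load, because every use can be served by the block argument:

  func.func @_QPexample(%lb: i32, %ub: i32, %step: i32) {
    omp.wsloop   for  (%arg0) : i32 = (%lb) to (%ub) inclusive step (%step) {
      %0 = arith.addi %arg0, %arg0 : i32
      omp.yield
    }
    return
  }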

llvmbot commented Sep 21, 2023

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Changes



Patch is 114.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67010.diff

32 Files Affected:

  • (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+9)
  • (modified) flang/include/flang/Tools/CLOptions.inc (+1)
  • (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1)
  • (added) flang/lib/Optimizer/Transforms/OMPLoopIndexMemToReg.cpp (+250)
  • (modified) flang/test/Lower/OpenMP/FIR/copyin.f90 (+24-29)
  • (modified) flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 (-2)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 (+18-21)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause.f90 (+7-21)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop-firstpriv.f90 (+4-4)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop.f90 (+8-16)
  • (modified) flang/test/Lower/OpenMP/FIR/simd.f90 (+8-24)
  • (modified) flang/test/Lower/OpenMP/FIR/stop-stmt-in-region.f90 (-2)
  • (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+10-13)
  • (modified) flang/test/Lower/OpenMP/FIR/unstructured.f90 (+2-10)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-chunks.f90 (+3-9)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-collapse.f90 (+3-9)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-monotonic.f90 (+1-4)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-nonmonotonic.f90 (+1-4)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add.f90 (+14-42)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-and.f90 (+5-16)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv.f90 (+5-16)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv.f90 (+5-16)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-or.f90 (+5-16)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-mul.f90 (+14-42)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-simd.f90 (+1-3)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-variable.f90 (+1-3)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop.f90 (+3-12)
  • (modified) flang/test/Lower/OpenMP/Todo/omp-default-clause-inner-loop.f90 (+1-4)
  • (modified) flang/test/Lower/OpenMP/hlfir-wsloop.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 (+1-5)
  • (added) flang/test/Transforms/omp-wsloop-index.mlir (+247)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 8aeb3e373298e88..0a9a3ca5bd0309b 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -79,6 +79,7 @@ createOMPEarlyOutliningPass();
 std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 createOMPMarkDeclareTargetPass();
+std::unique_ptr<mlir::Pass> createOMPLoopIndexMemToRegPass();
 
 // declarative passes
 #define GEN_PASS_REGISTRATION
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 9474edf13ce4639..8304b882d525cfd 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -326,4 +326,13 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
   ];
 }
 
+def OMPLoopIndexMemToReg : Pass<"omp-loop-index-mem2reg", "mlir::func::FuncOp"> {
+  let summary = "Pushes allocations for index variables of OpenMP loops into "
+                "the loop region and, if they are never passed by reference, "
+                "replaces them with the corresponding entry block arguments, "
+                "removing all redundant allocations in the process.";
+  let constructor = "::fir::createOMPLoopIndexMemToRegPass()";
+  let dependentDialects = ["fir::FIROpsDialect", "mlir::omp::OpenMPDialect"];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 616d9ddc066a75d..0b5e8a065680422 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -270,6 +270,7 @@ inline void createOpenMPFIRPassPipeline(
     pm.addPass(fir::createOMPEarlyOutliningPass());
     pm.addPass(fir::createOMPFunctionFilteringPass());
   }
+  pm.addPass(fir::createOMPLoopIndexMemToRegPass());
 }
 
 #if !defined(FLANG_EXCLUDE_CODEGEN)
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 3d2b7e5eaeade0a..306551b03ced1af 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_flang_library(FIRTransforms
   OMPEarlyOutlining.cpp
   OMPFunctionFiltering.cpp
   OMPMarkDeclareTarget.cpp
+  OMPLoopIndexMemToReg.cpp
 
   DEPENDS
   FIRDialect
diff --git a/flang/lib/Optimizer/Transforms/OMPLoopIndexMemToReg.cpp b/flang/lib/Optimizer/Transforms/OMPLoopIndexMemToReg.cpp
new file mode 100644
index 000000000000000..af117d625154bb4
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/OMPLoopIndexMemToReg.cpp
@@ -0,0 +1,250 @@
+//===- OMPLoopIndexMemToReg.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to push allocations into an OpenMP loop
+// operation region when they are used to store loop indices. They are then
+// removed, together with any associated load and store operations, if their
+// address is not otherwise needed, in which case uses of their values are
+// replaced with the block argument from which they were originally initialized.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include <llvm/ADT/MapVector.h>
+#include <llvm/ADT/SmallSet.h>
+#include <llvm/ADT/SmallVector.h>
+#include <llvm/Support/Casting.h>
+#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
+#include <mlir/IR/Builders.h>
+#include <mlir/IR/Value.h>
+#include <mlir/IR/ValueRange.h>
+#include <mlir/Support/LLVM.h>
+
+namespace fir {
+#define GEN_PASS_DEF_OMPLOOPINDEXMEMTOREG
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+template <typename LoopOpTy>
+class LoopProcessorHelper {
+  LoopOpTy loop;
+
+  bool allUsesInLoop(ValueRange stores) {
+    for (Value store : stores) {
+      for (OpOperand &use : store.getUses()) {
+        Operation *owner = use.getOwner();
+        if (owner->getParentOfType<LoopOpTy>() != loop.getOperation())
+          return false;
+      }
+    }
+    return true;
+  }
+
+  /// Check whether a given hlfir.declare known to only be used inside of the
+  /// loop and initialized by a fir.alloca operation also only used inside of
+  /// the loop can be removed and replaced by the block argument representing
+  /// the corresponding loop index.
+  static bool isDeclareRemovable(hlfir::DeclareOp declareOp) {
+    fir::AllocaOp allocaOp = llvm::dyn_cast_if_present<fir::AllocaOp>(
+        declareOp.getMemref().getDefiningOp());
+
+    // Check that the hlfir.declare is initialized by a fir.alloca that is only
+    // used as argument to that operation.
+    if (!allocaOp || !allocaOp.getResult().hasOneUse())
+      return false;
+
+    // Check that uses of the pointers can be replaced by the block argument.
+    for (OpOperand &use : declareOp.getOriginalBase().getUses()) {
+      Operation *owner = use.getOwner();
+      if (!isa<fir::StoreOp>(owner))
+        return false;
+    }
+    for (OpOperand &use : declareOp.getBase().getUses()) {
+      Operation *owner = use.getOwner();
+      if (!isa<fir::LoadOp>(owner))
+        return false;
+    }
+
+    return true;
+  }
+
+  /// Check whether a given fir.alloca known to only be used inside of the loop
+  /// can be removed and replaced by the block argument representing the
+  /// corresponding loop index.
+  static bool isAllocaRemovable(fir::AllocaOp allocaOp) {
+    // Check that uses of the pointer are all fir.load and fir.store.
+    for (OpOperand &use : allocaOp.getResult().getUses()) {
+      Operation *owner = use.getOwner();
+      if (!isa<fir::LoadOp>(owner) && !isa<fir::StoreOp>(owner))
+        return false;
+    }
+
+    return true;
+  }
+
+  /// Try to push an hlfir.declare operation defined outside of the loop into
+  /// the loop region, if all uses of that operation and the corresponding
+  /// fir.alloca are contained inside of the loop.
+  LogicalResult pushDeclareIntoLoop(hlfir::DeclareOp declareOp) {
+    // Check that all uses are inside of the loop.
+    if (!allUsesInLoop(declareOp->getResults()))
+      return failure();
+
+    // Push hlfir.declare into the beginning of the loop region.
+    Block &b = loop.getRegion().getBlocks().front();
+    declareOp->moveBefore(&b, b.begin());
+
+    // Find associated fir.alloca and push into the beginning of the loop
+    // region.
+    fir::AllocaOp allocaOp =
+        cast<fir::AllocaOp>(declareOp.getMemref().getDefiningOp());
+    Value allocaVal = allocaOp.getResult();
+
+    if (!allUsesInLoop(allocaVal))
+      return failure();
+
+    allocaOp->moveBefore(&b, b.begin());
+    return success();
+  }
+
+  /// Try to push a fir.alloca operation defined outside of the loop into the
+  /// loop region, if all uses of that operation are contained inside of the loop.
+  LogicalResult pushAllocaIntoLoop(fir::AllocaOp allocaOp) {
+    Value store = allocaOp.getResult();
+
+    // Check that all uses are inside of the loop.
+    if (!allUsesInLoop(store))
+      return failure();
+
+    // Push fir.alloca into the beginning of the loop region.
+    Block &b = loop.getRegion().getBlocks().front();
+    allocaOp->moveBefore(&b, b.begin());
+    return success();
+  }
+
+  void processLoopArg(BlockArgument arg, llvm::ArrayRef<Value> argStores,
+                      SmallPtrSetImpl<Operation *> &opsToDelete) {
+    llvm::SmallPtrSet<Operation *, 16> toDelete;
+    for (Value store : argStores) {
+      Operation *op = store.getDefiningOp();
+
+      // Skip argument if storage not defined by an operation.
+      if (!op)
+        return;
+
+      // Support HLFIR flow as well as regular FIR flow.
+      if (auto declareOp = dyn_cast<hlfir::DeclareOp>(op)) {
+        if (succeeded(pushDeclareIntoLoop(declareOp)) &&
+            isDeclareRemovable(declareOp)) {
+          // Mark hlfir.declare, fir.alloca and related uses for deletion.
+          for (OpOperand &use : declareOp.getOriginalBase().getUses())
+            toDelete.insert(use.getOwner());
+
+          for (OpOperand &use : declareOp.getBase().getUses())
+            toDelete.insert(use.getOwner());
+
+          Operation *allocaOp = declareOp.getMemref().getDefiningOp();
+          toDelete.insert(declareOp);
+          toDelete.insert(allocaOp);
+        }
+      } else if (auto allocaOp = dyn_cast<fir::AllocaOp>(op)) {
+        if (succeeded(pushAllocaIntoLoop(allocaOp)) &&
+            isAllocaRemovable(allocaOp)) {
+          // If the address of the index were needed, no further modifications
+          // would be made. Otherwise, its value can be used directly from the
+          // loop region's entry block arguments.
+
+          // Mark fir.alloca and related uses for deletion.
+          for (OpOperand &use : allocaOp.getResult().getUses())
+            toDelete.insert(use.getOwner());
+
+          // Delete now-unused fir.alloca.
+          toDelete.insert(allocaOp);
+        }
+      } else {
+        return;
+      }
+    }
+
+    // Only consider marked operations if all load, store and allocation
+    // operations associated with the given loop index can be removed.
+    opsToDelete.insert(toDelete.begin(), toDelete.end());
+
+    for (Operation *op : toDelete) {
+      // Replace all fir.load operations with the index as returned by the
+      // OpenMP loop operation.
+      if (isa<fir::LoadOp>(op))
+        op->replaceAllUsesWith(ValueRange(arg));
+      // Drop all uses of fir.alloca and hlfir.declare because their defining
+      // operations will be deleted as well.
+      else if (isa<fir::AllocaOp>(op) || isa<hlfir::DeclareOp>(op))
+        op->dropAllUses();
+    }
+  }
+
+public:
+  explicit LoopProcessorHelper(LoopOpTy loop) : loop(loop) {}
+
+  void process() {
+    llvm::SmallPtrSet<Operation *, 16> opsToDelete;
+    llvm::SmallVector<llvm::SmallVector<Value>> storeAddresses;
+    llvm::ArrayRef<BlockArgument> loopArgs = loop.getRegion().getArguments();
+
+    // Collect arguments of the loop operation.
+    for (BlockArgument arg : loopArgs) {
+      // Find fir.store uses of these indices and gather all addresses where
+      // they are stored.
+      llvm::SmallVector<Value> &argStores = storeAddresses.emplace_back();
+      for (OpOperand &argUse : arg.getUses())
+        if (auto storeOp = dyn_cast<fir::StoreOp>(argUse.getOwner()))
+          argStores.push_back(storeOp.getMemref());
+    }
+
+    // Process all loop indices and mark them for deletion independently of each
+    // other.
+    for (auto it : llvm::zip(loopArgs, storeAddresses))
+      processLoopArg(std::get<0>(it), std::get<1>(it), opsToDelete);
+
+    // Delete marked operations.
+    for (Operation *op : opsToDelete)
+      op->erase();
+  }
+};
+
+namespace {
+class OMPLoopIndexMemToRegPass
+    : public fir::impl::OMPLoopIndexMemToRegBase<OMPLoopIndexMemToRegPass> {
+public:
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    func->walk(
+        [&](omp::WsLoopOp loop) { LoopProcessorHelper(loop).process(); });
+
+    func.walk(
+        [&](omp::SimdLoopOp loop) { LoopProcessorHelper(loop).process(); });
+
+    func.walk(
+        [&](omp::TaskLoopOp loop) { LoopProcessorHelper(loop).process(); });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> fir::createOMPLoopIndexMemToRegPass() {
+  return std::make_unique<OMPLoopIndexMemToRegPass>();
+}
diff --git a/flang/test/Lower/OpenMP/FIR/copyin.f90 b/flang/test/Lower/OpenMP/FIR/copyin.f90
index ddfa0ea0914628f..3443b310074f5b0 100644
--- a/flang/test/Lower/OpenMP/FIR/copyin.f90
+++ b/flang/test/Lower/OpenMP/FIR/copyin.f90
@@ -138,17 +138,15 @@ subroutine copyin_derived_type()
 ! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_QFcombined_parallel_worksharing_loopEx6) : !fir.ref<i32>
 ! CHECK:         %[[VAL_2:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
 ! CHECK:         omp.parallel   {
-! CHECK:           %[[VAL_3:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
-! CHECK:           %[[VAL_4:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
-! CHECK:           %[[VAL_5:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
-! CHECK:           fir.store %[[VAL_5]] to %[[VAL_4]] : !fir.ref<i32>
+! CHECK:           %[[VAL_3:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
+! CHECK:           %[[VAL_4:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
+! CHECK:           fir.store %[[VAL_4]] to %[[VAL_3]] : !fir.ref<i32>
 ! CHECK:           omp.barrier
-! CHECK:           %[[VAL_6:.*]] = arith.constant 1 : i32
-! CHECK:           %[[VAL_7:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
-! CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i32
-! CHECK:           omp.wsloop   for  (%[[VAL_9:.*]]) : i32 = (%[[VAL_6]]) to (%[[VAL_7]]) inclusive step (%[[VAL_8]]) {
-! CHECK:             fir.store %[[VAL_9]] to %[[VAL_3]] : !fir.ref<i32>
-! CHECK:             fir.call @_QPsub4(%[[VAL_4]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i32
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+! CHECK:           %[[VAL_7:.*]] = arith.constant 1 : i32
+! CHECK:           omp.wsloop   for  (%[[VAL_9:.*]]) : i32 = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_7]]) {
+! CHECK:             fir.call @_QPsub4(%[[VAL_3]]) {{.*}}: (!fir.ref<i32>) -> ()
 ! CHECK:             omp.yield
 ! CHECK:           }
 ! CHECK:           omp.terminator
@@ -269,30 +267,27 @@ subroutine common_1()
 !CHECK: %[[val_7:.*]] = fir.coordinate_of %[[val_6]], %[[val_c4]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
 !CHECK: %[[val_8:.*]] = fir.convert %[[val_7]] : (!fir.ref<i8>) -> !fir.ref<i32>
 !CHECK: omp.parallel {
-!CHECK: %[[val_9:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
-!CHECK: %[[val_10:.*]] = omp.threadprivate %[[val_1]] : !fir.ref<!fir.array<8xi8>> -> !fir.ref<!fir.array<8xi8>>
-!CHECK: %[[val_11:.*]] = fir.convert %[[val_10]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[val_9:.*]] = omp.threadprivate %[[val_1]] : !fir.ref<!fir.array<8xi8>> -> !fir.ref<!fir.array<8xi8>>
+!CHECK: %[[val_10:.*]] = fir.convert %[[val_9]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
 !CHECK: %[[val_c0_0:.*]] = arith.constant 0 : index
-!CHECK: %[[val_12:.*]] = fir.coordinate_of %[[val_11]], %[[val_c0_0]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
-!CHECK: %[[val_13:.*]] = fir.convert %[[val_12]] : (!fir.ref<i8>) -> !fir.ref<i32>
-!CHECK: %[[val_14:.*]] = fir.convert %[[val_10]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[val_11:.*]] = fir.coordinate_of %[[val_10]], %[[val_c0_0]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[val_12:.*]] = fir.convert %[[val_11]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[val_13:.*]] = fir.convert %[[val_9]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
 !CHECK: %[[val_c4_1:.*]] = arith.constant 4 : index
-!CHECK: %[[val_15:.*]] = fir.coordinate_of %[[val_14]], %[[val_c4_1]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
-!CHECK: %[[val_16:.*]] = fir.convert %[[val_15]] : (!fir.ref<i8>) -> !fir.ref<i32>
-!CHECK: %[[val_17:.*]] = fir.load %[[val_5]] : !fir.ref<i32>
-!CHECK: fir.store %[[val_17]] to %[[val_13]] : !fir.ref<i32>
-!CHECK: %[[val_18:.*]] = fir.load %[[val_8]] : !fir.ref<i32>
-!CHECK: fir.store %[[val_18]] to %[[val_16]] : !fir.ref<i32>
+!CHECK: %[[val_14:.*]] = fir.coordinate_of %[[val_13]], %[[val_c4_1]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[val_15:.*]] = fir.convert %[[val_14]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = fir.load %[[val_5]] : !fir.ref<i32>
+!CHECK: fir.store %[[val_16]] to %[[val_12]] : !fir.ref<i32>
+!CHECK: %[[val_17:.*]] = fir.load %[[val_8]] : !fir.ref<i32>
+!CHECK: fir.store %[[val_17]] to %[[val_15]] : !fir.ref<i32>
 !CHECK: omp.barrier
 !CHECK: %[[val_c1_i32:.*]] = arith.constant 1 : i32
-!CHECK: %[[val_19:.*]] = fir.load %[[val_13]] : !fir.ref<i32>
+!CHECK: %[[val_18:.*]] = fir.load %[[val_12]] : !fir.ref<i32>
 !CHECK: %[[val_c1_i32_2:.*]] = arith.constant 1 : i32
-!CHECK: omp.wsloop   for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_19]]) inclusive step (%[[val_c1_i32_2]]) {
-!CHECK: fir.store %[[arg]] to %[[val_9]] : !fir.ref<i32>
-!CHECK: %[[val_20:.*]] = fir.load %[[val_16]] : !fir.ref<i32>
-!CHECK: %[[val_21:.*]] = fir.load %[[val_9]] : !fir.ref<i32>
-!CHECK: %[[val_22:.*]] = arith.addi %[[val_20]], %[[val_21]] : i32
-!CHECK: fir.store %[[val_22]] to %[[val_16]] : !fir.ref<i32>
+!CHECK: omp.wsloop   for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_18]]) inclusive step (%[[val_c1_i32_2]]) {
+!CHECK: %[[val_19:.*]] = fir.load %[[val_15]] : !fir.ref<i32>
+!CHECK: %[[val_20:.*]] = arith.addi %[[val_19]], %[[arg]] : i32
+!CHECK: fir.store %[[val_20]] to %[[val_15]] : !fir.ref<i32>
 !CHECK: omp.yield
 !CHECK: }
 !CHECK: omp.terminator
diff --git a/flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 b/flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90
index 06f3e1ca82234ee..bba9dbc4fc4cb8e 100644
--- a/flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90
+++ b/flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90
@@ -1,7 +1,6 @@
 ! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s 
 
 !CHECK: func.func @_QPlastprivate_common() {
-!CHECK: %[[val_0:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
 !CHECK: %[[val_1:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFlastprivate_commonEi"}
 !CHECK: %[[val_2:.*]] = fir.address_of(@c_) : !fir.ref<!fir.array<8xi8>>
 !CHECK: %[[val_3:.*]] = fir.convert %[[val_2]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
@@ -18,7 +17,6 @@
 !CHECK: %[[val_c100_i32:.*]] = arith.constant 100 : i32
 !CHECK: %[[val_c1_i32_0:.*]] = arith.constant 1 : i32
 !CHECK: omp.wsloop   for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_c100_i32]]) inclusive step (%[[val_c1_i32_0]]) {
-!CHECK: fir.store %[[arg]] to %[[val_0]] : !fir.ref<i32>
 !CHECK: %[[val_11:.*]] = arith.cmpi eq, %[[arg]], %[[val_c100_i32]] : i32
 !CHECK: fir.if %[[val_11]] {
 !CHECK: %[[val_12:.*]] = fir.load %[[val_9]] : !fir.ref<f32>
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 b/flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90
index 3152f9c44d0c64a..8cf216361bcb606 100644
--- a/flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90
+++ b/flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90
@@ -8,34 +8,31 @@
 ! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFmultiple_private_fixEx"}
 ! CHECK:         omp.parallel {
 ! CHECK:           %[[PRIV_J:.*]] = fir.alloca i32 {bindc_name = "j", pinned
-! CHECK:           %[[PRIV_I:.*]] = fir.alloca i32 {adapt.valuebyref, pinned
 ! CHECK:           %[[PRIV_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned
 ! CHECK:           %[[ONE:.*]] = arith.constant 1 : i32
 ! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_4:.*]] : !fir.ref<i32>
 ! CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i32
-! CHECK:           omp.wsloop for (%[[VAL_6:.*]]) : i32 = (%[[ONE]]) to (%[[VAL_3]]) inclusive step (%[[VAL_5]]) {
-! CHECK:             fir.store %[[VAL_6]] to %[[PRIV_I]] : !fir.ref<i32>
-! CHECK:             %[[VAL_7:.*]] = arith.constant 1 : i32
-! CHECK:             %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (i32) -> index
-! CHECK:             %[[VAL_9:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
-! CHECK:             %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i32) -> index
-! CHECK:             %[[VAL_11:.*]] = arith.constant 1 : index
-! CHECK:             %[[LB:.*]] = fir.convert %[[VAL_8]] : (index) -> i...
[truncated]

kiranchandramohan commented Sep 21, 2023

I'm not sure I understand the motivation for this. Loops are not outlined; they are generally inlined in the function or parallel region where they appear, so the allocations for the loop index will live in that function or parallel region and can (hopefully) be promoted to registers by LLVM passes.

Also, wouldn't moving allocations inside the loop lead to runaway allocations?

DominikAdamski commented Sep 22, 2023

Hi,
We would like to introduce new worksharing-loop runtime functions for offloaded code. The main idea is as follows:

Input code:

 !$omp target parallel do
   do i = 1, 1024
      e(i) = i
   end do

Corresponding MLIR code:

omp.target map((tofrom -> %arg0 : !llvm.ptr<array<1024 x i32>>)) {
  omp.parallel {
    %i = llvm.alloca (...)  // we would like to move this alloca into the loop body
    omp.wsloop for (%arg1) : i32 = (%2) to (%1) inclusive step (%2) {
      llvm.store %arg1, %i {tbaa = [#tbaa_tag]} : !llvm.ptr<i32>
      %5 = llvm.load %i {tbaa = [#tbaa_tag]} : !llvm.ptr<i32>
      %6 = llvm.sext %5 : i32 to i64
      %7 = llvm.sub %6, %0 : i64
      %8 = llvm.getelementptr %arg0[0, %7] : (!llvm.ptr<array<1024 x i32>>, i64) -> !llvm.ptr<i32>
      llvm.store %5, %8 {tbaa = [#tbaa_tag]} : !llvm.ptr<i32>
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}

Desired LLVM IR:

gpu_kernel target_kernel (ptr %arg_e) {
 ; kmpc_target_init
 call __kmpc_parallel_51 (parallel_function, arg_e);
 ; kmpc_target_init
}

; function which corresponds to the MLIR parallel region
void parallel_function (ptr %tid.addr, ptr %zero.addr, ptr %arg_e) {
    ; no code for the loop condition or counter, etc.; only setting up the arguments for __kmpc_for_static_loop_4
    call __kmpc_for_static_loop_4(loop_body_function, num_iters, arg_e, ...)
}

; function which corresponds to the wsloop body:
void loop_body_function(int32 %cnt, struct {ptr %arg_e}) {
    ; LLVM IR code which corresponds to e[i] = i
    ; "i" from the source depends on %cnt, which is set by __kmpc_for_static_loop_4
    ; we do not need to pass variable "i" as part of the second argument;
    ; the %cnt variable handled by the OpenMP runtime is enough for performing the loop body code
}

Without Sergio's patch we get the following LLVM IR:

gpu_kernel target_kernel (ptr %arg_e) {
 ; kmpc_target_init
 call __kmpc_parallel_51 (parallel_function, arg_e);
 ; kmpc_target_init
}

; function which corresponds to the MLIR parallel region
void parallel_function(ptr %tid.addr, ptr %zero.addr, ptr %arg_e) {
    ; no code for the loop condition or counter; only setting up the arguments for __kmpc_for_static_loop_4
    %additional_i = alloca i32
    call __kmpc_for_static_loop_4(loop_body_function, num_iters, struct{arg_e, additional_i}, ...)
}

; function which corresponds to the wsloop body:
void loop_body_function(int32 %cnt, ptr struct{ptr %arg_e, additional_i}) {
    ; LLVM IR code which corresponds to e[i] = i
    ; store the new value inside additional_i
    ; "i" from the source code depends on %cnt, which is set by __kmpc_for_static_loop_4
}

Link to the similar OpenMP DeviceRTL functions: https://github.com/jdoerfert/llvm-project/blob/IPDPS22/openmp/libomptarget/DeviceRTL/src/Workshare.cpp

skatrak commented Sep 22, 2023

Just as an additional note, what we actually want is to remove the alloca and use the int32 %cnt value instead whenever possible. However, if the address of the variable is passed to another function from inside the loop body (in fact, that would be the default in Fortran if we called some subprocedure with the index variable as an argument), then that isn't possible. That is the case for which we'd bring the alloca inside the loop body: at least it avoids passing it from outside, and it becomes just a local stack variable. I hope that, between the comments, we're making the need for this pass (or something with the same effect) a bit clearer.
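
For illustration, here is a hedged sketch of that escape case (not taken from the patch; names and bounds are made up), using the omp.wsloop syntax from the updated tests. Only stage 1 of the pass applies: the fir.alloca is moved into the loop region but kept, because its address is passed to a call:

  omp.wsloop   for  (%arg0) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // Moved here from the enclosing region by stage 1; not removable because
    // the address escapes through the call below.
    %iv = fir.alloca i32 {adapt.valuebyref, pinned}
    fir.store %arg0 to %iv : !fir.ref<i32>
    fir.call @_QPsub(%iv) : (!fir.ref<i32>) -> ()
    omp.yield
  }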

If a pass like this were the agreed solution, it seems we should add the AutomaticAllocationScope trait to the OpenMP loop operations impacted by it (omp.wsloop, omp.simdloop and maybe omp.taskloop or others), which I suppose would address your concern about runaway allocations. But we've discussed that there may be a way to rely on existing LLVM infrastructure instead: the CodeExtractor class that would be used here seems able to identify the outside alloca as something that can be sunk into the outlined function, and once that is done it seems safe to assume that some pass in the LLVM optimization pipeline would recognize the allocation as redundant, since it is only used to store the contents of the index, which is already available as an SSA value. I'm currently looking into this possibility and I'll close this PR if I find that a simpler approach is possible.

@kiranchandramohan (Contributor) commented:

The other option is to delay materializing this loop index variable until LLVM code generation. You can have another entry block argument that is a reference: in the IR we would store the loop index to that reference, and while generating LLVM code we can materialize it as an LLVM alloca wherever we want.

omp.wsloop %i = %start to %end {
  bb0(%i, %i_ref):
    fir.store %i to %i_ref
}

skatrak commented Sep 28, 2023

I've got an update on this. I made a few changes (< 100 lines) in the OpenMP dialect to LLVM IR translation pass to detect allocas for primitive types that are only used inside the region associated with OpenMP loop operations (omp.wsloop and omp.simdloop), and to add llvm.lifetime.start and llvm.lifetime.end markers for them around the entry and exit points of the associated basic block(s).

These markers later allow the CodeExtractor to sink the allocas for index variables (and potentially others) into the outlined function created for the loop body, rather than passing them as arguments. The outlining would be implemented by the OpenMPIRBuilder after the new RTL functions mentioned by Dominik were available.
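
To give an idea of the intended shape of the translated LLVM IR (illustrative only; function, block and value names are invented, and the exact output depends on the OpenMPIRBuilder):

  define internal void @parallel_region(ptr %arg_e) {
  entry:
    %i.addr = alloca i32, align 4            ; index alloca defined outside the loop body block(s)
    br label %omp.wsloop.body

  omp.wsloop.body:                           ; block(s) generated for the omp.wsloop region
    call void @llvm.lifetime.start.p0(i64 4, ptr %i.addr)
    ; ... loop body: all loads/stores of %i.addr happen here ...
    call void @llvm.lifetime.end.p0(i64 4, ptr %i.addr)
    br label %exit

  exit:
    ret void
  }

  declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
  declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)

Because the marked lifetime is confined to the block(s) being outlined, the CodeExtractor can sink %i.addr into the extracted loop-body function instead of passing its address through the aggregate argument.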

My idea was to wait until at least the PRs adding these other functions were up before creating one for my changes, since it wouldn't make much of a real difference on its own, but I wanted to check whether it would be a good idea to still create it now so we can discuss whether that or a change to the MLIR representation of the OpenMP loop operations would be a better solution.

skatrak added a commit to skatrak/llvm-project that referenced this pull request Dec 8, 2023
…d allocations

This patch introduces `llvm.lifetime.start` and `llvm.lifetime.end` markers
around the LLVM basic blocks containing the translated body of `omp.wsloop` and
`omp.simdloop` operations, for all `alloca` instructions that are defined
outside of that block but only ever used inside of it.

This is achieved by analyzing the MLIR regions associated to the aforementioned
OpenMP dialect loop operations during translation to LLVM IR. The purpose of
this addition is to enable sinking these allocations into the region if it gets
outlined into a separate function, avoiding the need to pass the pointer as an
argument.

It is a less intrusive alternative to llvm#67010 that addresses the same problem: the interaction between redundant allocations for OpenMP loop indices and loop body outlining for target offload using new DeviceRTL functions.
skatrak commented Jan 15, 2024

Closing due to better solutions being discussed.

@skatrak skatrak closed this Jan 15, 2024