[Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR #111155

Open
wants to merge 1 commit into
base: main

Conversation

kaviya2510
This patch adds support for lowering the task_reduction and in_reduction clauses to MLIR.

llvmbot added the flang, flang:fir-hlfir, and flang:openmp labels on Oct 4, 2024
llvmbot (Collaborator) commented Oct 4, 2024

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kaviya Rajendiran (kaviya2510)

Changes

This patch adds support for lowering the task_reduction and in_reduction clauses to MLIR.


Patch is 22.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111155.diff

8 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+70-1)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+10)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+41-13)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+46-11)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+2-1)
  • (added) flang/test/Lower/OpenMP/task_array_reduction.f90 (+50)
  • (added) flang/test/Lower/OpenMP/task_in_reduction.f90 (+48)
  • (added) flang/test/Lower/OpenMP/task_reduction.f90 (+43)
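
Both new `ClauseProcessor` methods in the diff follow the same accumulate-per-clause pattern: each clause occurrence fills local lists, which are then appended to the caller's output, and the optional type list is derived from the variables. The sketch below illustrates that pattern with standard-library containers only. `Value`, `ClauseOps`, and `appendClause` are hypothetical stand-ins invented for illustration, not MLIR or Flang APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Value: a name plus a type string.
struct Value {
  std::string name;
  std::string type;
};

// Hypothetical stand-in for the clause operand structure being filled in.
struct ClauseOps {
  std::vector<Value> vars;
  std::vector<bool> byref;
};

// Append the lists produced for one clause occurrence to the accumulated
// output. outTypes may be null, mirroring the nullable out-parameters
// used by the methods in the patch.
void appendClause(const std::vector<Value> &localVars,
                  const std::vector<bool> &localByref, ClauseOps &result,
                  std::vector<std::string> *outTypes) {
  assert(localVars.size() == localByref.size() &&
         "expected one byref flag per variable");

  // Copy the local lists into the output, preserving earlier entries.
  std::copy(localVars.begin(), localVars.end(),
            std::back_inserter(result.vars));
  std::copy(localByref.begin(), localByref.end(),
            std::back_inserter(result.byref));

  // Derive one type per variable, only if the caller asked for types.
  if (outTypes) {
    outTypes->reserve(outTypes->size() + localVars.size());
    std::transform(localVars.begin(), localVars.end(),
                   std::back_inserter(*outTypes),
                   [](const Value &v) { return v.type; });
  }
}
```

Because every call appends rather than overwrites, a directive that repeats the clause ends up with the operands of all occurrences, in source order.
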
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a4d2524bccf5c3..95ab51809dcf94 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1063,7 +1063,7 @@ bool ClauseProcessor::processReduction(
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
         llvm::SmallVector<const semantics::Symbol *> reductionSyms;
         ReductionProcessor rp;
-        rp.addDeclareReduction(
+        rp.addDeclareReduction<omp::clause::Reduction>(
             currentLocation, converter, clause, reductionVars, reduceVarByRef,
             reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
 
@@ -1085,6 +1085,75 @@ bool ClauseProcessor::processReduction(
       });
 }
 
+bool ClauseProcessor::processTaskReduction(
+    mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
+    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
+    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
+  return findRepeatableClause<omp::clause::TaskReduction>(
+      [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
+        llvm::SmallVector<mlir::Value> taskReductionVars;
+        llvm::SmallVector<bool> taskReductionByref;
+        llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
+        llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
+        ReductionProcessor rp;
+        rp.addDeclareReduction<omp::clause::TaskReduction>(
+            currentLocation, converter, clause, taskReductionVars, taskReductionByref,
+            taskReductionDeclSymbols, outReductionSyms ? &taskReductionSyms : nullptr);
+
+        // Copy local lists into the output.
+        llvm::copy(taskReductionVars, std::back_inserter(result.taskReductionVars));
+        llvm::copy(taskReductionByref, std::back_inserter(result.taskReductionByref));
+        llvm::copy(taskReductionDeclSymbols,
+                   std::back_inserter(result.taskReductionSyms));
+
+        if (outReductionTypes) {
+          outReductionTypes->reserve(outReductionTypes->size() +
+                                     taskReductionVars.size());
+          llvm::transform(taskReductionVars, std::back_inserter(*outReductionTypes),
+                          [](mlir::Value v) { return v.getType(); });
+        }
+
+        if (outReductionSyms)
+          llvm::copy(taskReductionSyms, std::back_inserter(*outReductionSyms));
+      });
+}
+
+bool ClauseProcessor::processInReduction(
+    mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
+    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
+    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
+  return findRepeatableClause<omp::clause::InReduction>(
+      [&](const omp::clause::InReduction &clause,
+          const parser::CharBlock &source) {
+        llvm::SmallVector<mlir::Value> inReductionVars;
+        llvm::SmallVector<bool> inReductionByref;
+        llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
+        llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
+        ReductionProcessor rp;
+        rp.addDeclareReduction<omp::clause::InReduction>(
+            currentLocation, converter, clause, inReductionVars,
+            inReductionByref, inReductionDeclSymbols,
+            outReductionSyms ? &inReductionSyms : nullptr);
+
+        // Copy local lists into the output.
+        llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
+        llvm::copy(inReductionByref, std::back_inserter(result.inReductionByref));
+        llvm::copy(inReductionDeclSymbols,
+                   std::back_inserter(result.inReductionSyms));
+
+        if (outReductionTypes) {
+          outReductionTypes->reserve(outReductionTypes->size() +
+                                     inReductionVars.size());
+          llvm::transform(inReductionVars,
+                          std::back_inserter(*outReductionTypes),
+                          [](mlir::Value v) { return v.getType(); });
+        }
+
+        if (outReductionSyms)
+          llvm::copy(inReductionSyms, std::back_inserter(*outReductionSyms));
+      });
+}
+
 bool ClauseProcessor::processTo(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::To>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 0c8e7bd47ab5a6..04416a927a1c37 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -129,6 +129,16 @@ class ClauseProcessor {
       llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
       llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
           nullptr) const;
+  bool processTaskReduction(
+      mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
+      llvm::SmallVectorImpl<mlir::Type> *taskReductionTypes = nullptr,
+      llvm::SmallVectorImpl<const semantics::Symbol *> *taskReductionSyms =
+          nullptr) const;
+  bool processInReduction(
+      mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
+      llvm::SmallVectorImpl<mlir::Type> *inReductionTypes = nullptr,
+      llvm::SmallVectorImpl<const semantics::Symbol *> *inReductionSyms =
+          nullptr) const;
   bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
   bool processUseDeviceAddr(
       lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 60c83586e468b6..850f32ff0bf030 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1244,29 +1244,34 @@ static void genTaskClauses(lower::AbstractConverter &converter,
                            semantics::SemanticsContext &semaCtx,
                            lower::StatementContext &stmtCtx,
                            const List<Clause> &clauses, mlir::Location loc,
-                           mlir::omp::TaskOperands &clauseOps) {
+                           mlir::omp::TaskOperands &clauseOps,
+                           llvm::SmallVectorImpl<mlir::Type> &inReductionTypes,
+                           llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processAllocate(clauseOps);
   cp.processDepend(clauseOps);
   cp.processFinal(stmtCtx, clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+  cp.processInReduction(loc, clauseOps, &inReductionTypes,
+                          &inReductionSyms);
   cp.processMergeable(clauseOps);
   cp.processPriority(stmtCtx, clauseOps);
   cp.processUntied(clauseOps);
   // TODO Support delayed privatization.
 
-  cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
+  cp.processTODO<clause::Affinity, clause::Detach>(
       loc, llvm::omp::Directive::OMPD_task);
 }
 
 static void genTaskgroupClauses(lower::AbstractConverter &converter,
-                                semantics::SemanticsContext &semaCtx,
-                                const List<Clause> &clauses, mlir::Location loc,
-                                mlir::omp::TaskgroupOperands &clauseOps) {
+    semantics::SemanticsContext &semaCtx,
+    const List<Clause> &clauses, mlir::Location loc,
+    mlir::omp::TaskgroupOperands &clauseOps,
+    llvm::SmallVectorImpl<mlir::Type> &taskReductionTypes,
+    llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processAllocate(clauseOps);
-  cp.processTODO<clause::TaskReduction>(loc,
-                                        llvm::omp::Directive::OMPD_taskgroup);
+  cp.processTaskReduction(loc, clauseOps, &taskReductionTypes, &taskReductionSyms);
 }
 
 static void genTaskwaitClauses(lower::AbstractConverter &converter,
@@ -1866,13 +1871,26 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           ConstructQueue::const_iterator item) {
   lower::StatementContext stmtCtx;
   mlir::omp::TaskOperands clauseOps;
-  genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+  llvm::SmallVector<mlir::Type> inReductionTypes;
+  llvm::SmallVector<const semantics::Symbol *> inreductionSyms;
+  genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
+                 inReductionTypes, inreductionSyms);
+
+  auto reductionCallback = [&](mlir::Operation *op) {
+    genReductionVars(op, converter, loc, inreductionSyms, inReductionTypes);
+    return inreductionSyms;
+  };
 
-  return genOpWithBody<mlir::omp::TaskOp>(
+  auto taskOp = genOpWithBody<mlir::omp::TaskOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_task)
-          .setClauses(&item->clauses),
+          .setClauses(&item->clauses)
+          .setGenRegionEntryCb(reductionCallback),
       queue, item, clauseOps);
+  // Add reduction variables as arguments
+  llvm::SmallVector<mlir::Location> blockArgLocs(inReductionTypes.size(), loc);
+  taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
+  return taskOp;
 }
 
 static mlir::omp::TaskgroupOp
@@ -1882,13 +1900,21 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                const ConstructQueue &queue,
                ConstructQueue::const_iterator item) {
   mlir::omp::TaskgroupOperands clauseOps;
-  genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+  llvm::SmallVector<mlir::Type> taskReductionTypes;
+  llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
+  genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
+                      taskReductionTypes, taskReductionSyms);
 
-  return genOpWithBody<mlir::omp::TaskgroupOp>(
+  auto taskgroupOp = genOpWithBody<mlir::omp::TaskgroupOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_taskgroup)
           .setClauses(&item->clauses),
       queue, item, clauseOps);
+
+  // Add reduction variables as arguments
+  llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
+  taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
+  return taskgroupOp;
 }
 
 static mlir::omp::TaskwaitOp
@@ -2764,7 +2790,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::ThreadLimit>(clause.u) &&
         !std::holds_alternative<clause::Threads>(clause.u) &&
         !std::holds_alternative<clause::UseDeviceAddr>(clause.u) &&
-        !std::holds_alternative<clause::UseDevicePtr>(clause.u)) {
+        !std::holds_alternative<clause::UseDevicePtr>(clause.u) &&
+        !std::holds_alternative<clause::TaskReduction>(clause.u) &&
+        !std::holds_alternative<clause::InReduction>(clause.u)) {
       TODO(clauseLocation, "OpenMP Block construct clause");
     }
   }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 9da15ba303a475..deb25b4fff3792 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -24,6 +24,7 @@
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/Support/CommandLine.h"
+#include <type_traits>
 
 static llvm::cl::opt<bool> forceByrefReduction(
     "force-byref-reduction",
@@ -34,6 +35,38 @@ namespace Fortran {
 namespace lower {
 namespace omp {
 
+// Explicit template instantiations
+template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
+        mlir::Location currentLocation,
+        lower::AbstractConverter &converter,
+        const omp::clause::Reduction &reduction,
+        llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+        llvm::SmallVectorImpl<bool> &reduceVarByRef,
+        llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+        llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols
+    );
+
+template void ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
+        mlir::Location currentLocation,
+        lower::AbstractConverter &converter,
+        const omp::clause::TaskReduction &reduction,
+        llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+        llvm::SmallVectorImpl<bool> &reduceVarByRef,
+        llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+        llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols
+    );
+
+template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
+        mlir::Location currentLocation,
+        lower::AbstractConverter &converter,
+        const omp::clause::InReduction &reduction,
+        llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+        llvm::SmallVectorImpl<bool> &reduceVarByRef,
+        llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
+        llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols
+    );
+
+
 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
     const omp::clause::ProcedureDesignator &pd) {
   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -716,24 +749,26 @@ static bool doReductionByRef(mlir::Value reductionVar) {
   return false;
 }
 
-void ReductionProcessor::addDeclareReduction(
-    mlir::Location currentLocation, lower::AbstractConverter &converter,
-    const omp::clause::Reduction &reduction,
+template <class T>
+void ReductionProcessor::addDeclareReduction(mlir::Location currentLocation,
+    lower::AbstractConverter &converter,
+    const T &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
-          reduction.t))
-    TODO(currentLocation, "Reduction modifiers are not supported");
+  if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
+      if (std::get<std::optional<typename T::ReductionModifier>>(
+            reduction.t))
+      TODO(currentLocation, "Reduction modifiers are not supported");
+    }
 
   mlir::omp::DeclareReductionOp decl;
-  const auto &redOperatorList{
-      std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
-  assert(redOperatorList.size() == 1 && "Expecting single operator");
-  const auto &redOperator = redOperatorList.front();
+    const auto &redOperatorList{
+      std::get<typename T::ReductionIdentifiers>(reduction.t)};
+    assert(redOperatorList.size() == 1 && "Expecting single operator");
+    const auto &redOperator = redOperatorList.front();
   const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
 
   if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 0ed5782e5da1b7..d34db0618c7cda 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -120,9 +120,10 @@ class ReductionProcessor {
 
   /// Creates a reduction declaration and associates it with an OpenMP block
   /// directive.
+  template <class T>
   static void addDeclareReduction(
       mlir::Location currentLocation, lower::AbstractConverter &converter,
-      const omp::clause::Reduction &reduction,
+      const T &reduction,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
       llvm::SmallVectorImpl<bool> &reduceVarByRef,
       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
diff --git a/flang/test/Lower/OpenMP/task_array_reduction.f90 b/flang/test/Lower/OpenMP/task_array_reduction.f90
new file mode 100644
index 00000000000000..74693343744c26
--- /dev/null
+++ b/flang/test/Lower/OpenMP/task_array_reduction.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! CHECK-LABEL:  omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref<!fir.box<!fir.array<?xf32>>> alloc {
+! [...]
+! CHECK:  omp.yield
+! CHECK-LABEL:  } init {
+! [...]
+! CHECK:  omp.yield
+! CHECK-LABEL:  } combiner {
+! [...]
+! CHECK:  omp.yield
+! CHECK-LABEL:  }  cleanup {
+! [...]
+! CHECK:  omp.yield
+! CHECK:  }
+
+! CHECK-LABEL:  func.func @_QPtaskreduction
+! CHECK-SAME:  (%[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
+! CHECK:  %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK:  %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]]
+! CHECK-SAME:  {uniq_name = "_QFtaskreductionEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
+! CHECK:  omp.parallel {
+! CHECK:  %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
+! CHECK:  fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
+! CHECK:  omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] ->  %[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+! CHECK:  %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
+! CHECK:  fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
+! CHECK:  omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+! [...]
+! CHECK:  omp.terminator
+! CHECK:  }
+! CHECK:  omp.terminator
+! CHECK:  }
+! CHECK:  omp.terminator
+! CHECK:  }
+! CHECK:  return
+! CHECK:  }
+
+subroutine taskReduction(x)
+   real, dimension(:) :: x
+   !$omp parallel
+   !$omp taskgroup task_reduction(+:x)
+   !$omp task in_reduction(+:x)
+   x = x + 1
+   !$omp end task
+   !$omp end taskgroup
+   !$omp end parallel
+end subroutine
+
diff --git a/flang/test/Lower/OpenMP/task_in_reduction.f90 b/flang/test/Lower/OpenMP/task_in_reduction.f90
new file mode 100644
index 00000000000000..26c079d5ac8aa5
--- /dev/null
+++ b/flang/test/Lower/OpenMP/task_in_reduction.f90
@@ -0,0 +1,48 @@
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK:  %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK:  omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL:  func.func @_QPin_reduction() {
+!CHECK:  %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFin_reductionEx"}
+!CHECK:  %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFin_reductionEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:  %[[VAL_2:.*]] = arith.constant 0 : i32
+!CHECK:  hlfir.assign %[[VAL_2]] to %[[VAL_1]]#0 : i32, !fir.ref<i32>
+!CHECK:  omp.parallel {
+!CHECK:  omp.taskgroup task_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_3:.*]] : !fir.ref<i32>) {
+!CHECK:  omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_4:.*]] : !fir.ref<i32>) {
+!CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = "_QFin_reductionEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+!CHECK:  %[[VAL_7:.*]] = arith.constant 1 : i32
+!CHECK:  %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i32
+!CHECK:  hlfir.assign %[[VAL_8]] to %[[VAL_5]]#0 : i32, !fir.ref<i32>
+!CHECK:  omp.terminator
+!CHECK:  }
+!CHECK:  omp.terminator
+!CHECK:  }
+!CHECK:  omp.terminator
+!CHECK:  }
+!CHECK:  return
+!CHECK:  }
+
+subroutine in_reduction
+   integer :: x
+   x = 0
+   !$omp parallel
+   !$omp taskgroup task_reduction(+:x)
+   !$omp task in_reduction(+:x)
+   x = x + 1
+   !$omp end task
+   !$omp end taskgroup
+   !$omp end parallel
+end subroutine
+
diff --git...
[truncated]
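
The `ReductionProcessor` change above turns `addDeclareReduction` into a template over the clause type, guards the modifier check with `if constexpr`, and explicitly instantiates the template for each supported clause. The following is a minimal, standalone sketch of that technique. The `Reduction` and `TaskReduction` structs and the `describe` function are hypothetical and much simpler than the Flang types; they only assume that one clause type carries an optional modifier and the other does not.

```cpp
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>

// Hypothetical clause types: only Reduction has a modifier.
struct Reduction {
  enum class ReductionModifier { Task, Inscan };
  std::tuple<std::optional<ReductionModifier>, std::string> t;
};
struct TaskReduction {
  std::tuple<std::string> t;
};

// Declaration, as it would appear in a header.
template <class T>
std::string describe(const T &clause);

// Definition, as it would appear in the .cpp file.
template <class T>
std::string describe(const T &clause) {
  if constexpr (std::is_same_v<T, Reduction>) {
    // Discarded for every other T, so T::ReductionModifier does not
    // need to exist for clause types without a modifier.
    if (std::get<std::optional<typename T::ReductionModifier>>(clause.t))
      return "unsupported modifier";
  }
  return "reduce " + std::get<std::string>(clause.t);
}

// Explicit instantiation definitions: they make the out-of-header
// template definition usable from other translation units.
template std::string describe<Reduction>(const Reduction &);
template std::string describe<TaskReduction>(const TaskReduction &);
```

Keeping the definition in the source file avoids pulling the implementation into the header, at the cost of having to list every clause type that callers may use.
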

github-actions bot commented Oct 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

kiranchandramohan (Contributor) left a comment

Please add the translation from OpenMP dialect to LLVMIR before lowering to MLIR. Otherwise this will manifest as a crash.

kaviya2510 changed the title from "Support for lowering task_reduction and in_reduction to MLIR" to "[Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR" on Oct 4, 2024
kaviya2510 (Author) commented

> Please add the translation from OpenMP dialect to LLVMIR before lowering to MLIR. Otherwise this will manifest as a crash.

Sure, I will do it.

kaviya2510 force-pushed the task_reduction branch 3 times, most recently from 3247466 to 8aeb419, on October 7, 2024 07:52