Skip to content

[WIP][PoC][flang] Re-use OpenMP data environemnt clauses for locality spec #128148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Feb 21, 2025

This is a PoC to write a proper RFC based on later. This is not meant to be merged!

Now that we started working on mapping do concurrent loop nests to corresponding OpenMP constructs (and later to OpenACC), we come across the following problem: How can we map do concurrent's locality specifiers to their corresponding OpenMP/ACC data environment clauses?

This is not easy at the moment because locality specifiers are handled on the PFT to MLIR lowering level which makes discovering the ops corresponding to them more difficult (or even not possible) during do concurrent to OpenMP mapping.

One way to handle this problem would be to use something similar to delayed privatization that we have been working on for the OpenMP dialect recently. So on the MLIR level, the following do concurrent loop:

subroutine foo
  implicit none
  integer :: i, local_var!, local_init_var

  do concurrent (i=1:10) local(local_var) local_init(local_init_var)
    if (i < 5) then
      local_var = 42
    else 
      !local_init_var = 84
    end if
  end do
end subroutine

would look something like this:

    %0 = fir.alloca i32 {bindc_name = "i"}
    %1:2 = hlfir.declare %0 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomploopEi"}
    %3:2 = hlfir.declare %2 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %4 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFomploopElocal_init_var"}
    %5:2 = hlfir.declare %4 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %6 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFomploopElocal_var"}
    %7:2 = hlfir.declare %6 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    %c1_i32 = arith.constant 1 : i32
    %8 = fir.convert %c1_i32 : (i32) -> index
    %c10_i32 = arith.constant 10 : i32
    %9 = fir.convert %c10_i32 : (i32) -> index
    %c1 = arith.constant 1 : index
    // Instead of using "private" we can use "local".
    fir.do_loop %arg0 = %8 to %9 step %c1 unordered private(@local_privatizer %7#0 -> %arg1, @local_init_privatizer %5#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
      %10 = fir.convert %arg0 : (index) -> i32
      fir.store %10 to %1#1 : !fir.ref<i32>
      %11:2 = hlfir.declare %arg1 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      %12:2 = hlfir.declare %arg2 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      %13 = fir.load %1#0 : !fir.ref<i32>
      %c5_i32 = arith.constant 5 : i32
      %14 = arith.cmpi slt, %13, %c5_i32 : i32
      fir.if %14 {
        %c42_i32 = arith.constant 42 : i32
        hlfir.assign %c42_i32 to %11#0 : i32, !fir.ref<i32>
      } else {
        %c84_i32 = arith.constant 84 : i32
        hlfir.assign %c84_i32 to %12#0 : i32, !fir.ref<i32>
      }
    }

To that end, it would be nice to:

  1. Extract the table-gen records we already for OpenMP into a separate "Data Environment" dialect.
  2. Use the records in that dialect for both of OpenMP and do concurrent (and later for OpenACC).
  3. We can do this hopefully for both local/private-related clauses/specifiers as well as reduction.

This is a PoC to validate that idea. For now it only reuses the OpenMP stuff just to showcase how it looks like for do concurrent. The PoC contains a sample to test the current prototyped functionality.

Current status of the PoC:

  • Extend fir.do_loop to reuse OpenMP clause table-gen records
  • Parsing and printing for fir.do_loop with private specifiers
  • Basic lowering of fir.do_loop local specifiers
  • Basic lowering of fir.do_loop's local_init specifier
  • PFT to MLIR lowring using the MLIR locality specifiers.

Each of the items above has a corresponding self-contained commit to demo the needed changes in that part of the pipeline.

@ergawy ergawy marked this pull request as draft February 21, 2025 08:35
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This is a PoC to write a proper RFC based on later. This is not meant to be merged!

Now that we started working on mapping do concurrent loop nests to corresponding OpenMP constructs (and later to OpenACC), we come across the following problem: How can we map do concurrent's locality specifiers to their corresponding OpenMP/ACC data environment clauses?

This is not easy at the moment because locality specifiers are handled on the PFT to MLIR lowering level which makes discovering the ops corresponding to them more difficult (or even not possible) during do concurrent to OpenMP mapping.

One way to handle this problem would be use something similar to delayed privatization that we have been working on for the OpenMP dialect recently. So on the MLIR level, the following do concurrent loop:

subroutine foo
  implicit none
  integer :: i, local_var!, local_init_var

  do concurrent (i=1:10) local(local_var) local_init(local_init_var)
    if (i &lt; 5) then
      local_var = 42
    else 
      !local_init_var = 84
    end if
  end do
end subroutine

would look something like this:

    %0 = fir.alloca i32 {bindc_name = "i"}
    %1:2 = hlfir.declare %0 {uniq_name = "_QFomploopEi"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
    %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomploopEi"}
    %3:2 = hlfir.declare %2 {uniq_name = "_QFomploopEi"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
    %4 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFomploopElocal_init_var"}
    %5:2 = hlfir.declare %4 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
    %6 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFomploopElocal_var"}
    %7:2 = hlfir.declare %6 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
    %c1_i32 = arith.constant 1 : i32
    %8 = fir.convert %c1_i32 : (i32) -&gt; index
    %c10_i32 = arith.constant 10 : i32
    %9 = fir.convert %c10_i32 : (i32) -&gt; index
    %c1 = arith.constant 1 : index
    // Instead of using "private" we can use "local".
    fir.do_loop %arg0 = %8 to %9 step %c1 unordered private(@<!-- -->local_privatizer %7#<!-- -->0 -&gt; %arg1, @<!-- -->local_init_privatizer %5#<!-- -->0 -&gt; %arg2 : !fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;) {
      %10 = fir.convert %arg0 : (index) -&gt; i32
      fir.store %10 to %1#<!-- -->1 : !fir.ref&lt;i32&gt;
      %11:2 = hlfir.declare %arg1 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
      %12:2 = hlfir.declare %arg2 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref&lt;i32&gt;) -&gt; (!fir.ref&lt;i32&gt;, !fir.ref&lt;i32&gt;)
      %13 = fir.load %1#<!-- -->0 : !fir.ref&lt;i32&gt;
      %c5_i32 = arith.constant 5 : i32
      %14 = arith.cmpi slt, %13, %c5_i32 : i32
      fir.if %14 {
        %c42_i32 = arith.constant 42 : i32
        hlfir.assign %c42_i32 to %11#<!-- -->0 : i32, !fir.ref&lt;i32&gt;
      } else {
        %c84_i32 = arith.constant 84 : i32
        hlfir.assign %c84_i32 to %12#<!-- -->0 : i32, !fir.ref&lt;i32&gt;
      }
    }

To that end, it would be nice to:

  1. Extract the table-gen records we already for OpenMP into a separate "Data Environment" dialect.
  2. Use the records in that dialect for both of OpenMP and do concurrent (and later for OpenACC).
  3. We can do this hopefully for both local/private-related clauses/specifiers as well as reduction.

This is a PoC to validate that idea. For now it only reuses the OpenMP stuff just to showcase how it looks like for do concurrent. The PoC contains a sample to test the current prototyped functionality.

Current status of the PoC:

  • Extend fir.do_loop to reuse OpenMP clause table-gen records
  • Parsing and printing for fir.do_loop with private specifiers
  • Basic lowering of fir.do_loop local specifiers
  • Basic lowering of fir.do_loop's local_init specifier
  • PFT to MLIR lowring using the MLIR locality specifiers.

Each of the checked items above has a corresponding self-contained commit to demo the needed changes in that part of the pipeline.


Full diff: https://github.com/llvm/llvm-project/pull/128148.diff

5 Files Affected:

  • (added) do_loop_with_local_and_local_init.mlir (+49)
  • (modified) flang/include/flang/Optimizer/Dialect/CMakeLists.txt (+2-2)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+29-8)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+89-17)
  • (modified) flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp (+57)
diff --git a/do_loop_with_local_and_local_init.mlir b/do_loop_with_local_and_local_init.mlir
new file mode 100644
index 0000000000000..06510b4433f1a
--- /dev/null
+++ b/do_loop_with_local_and_local_init.mlir
@@ -0,0 +1,49 @@
+// For testing:
+// 1. parsing/printing (roundtripping): `fir-opt do_loop_with_local_and_local_init.mlir -o roundtrip.mlir`
+// 2. Lowering locality specs during CFG: `fir-opt --cfg-conversion do_loop_with_local_and_local_init.mlir -o after_cfg_lowering.mlir`
+
+// TODO I will add both of the above steps as proper tests when the PoC is complete.
+module attributes {dlti.dl_spec = #dlti.dl_spec<i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 21.0.0 (/home/kaergawy/git/aomp20.0/llvm-project/flang c8cf5a644886bb8dd3ad19be6e3b916ffcbd222c)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+  omp.private {type = private} @local_privatizer : i32
+
+  omp.private {type = firstprivate} @local_init_privatizer : i32 copy {
+  ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
+      %0 = fir.load %arg0 : !fir.ref<i32>
+      fir.store %0 to %arg1 : !fir.ref<i32>
+      omp.yield(%arg1 : !fir.ref<i32>)
+  }
+
+  func.func @_QPomploop() {
+    %0 = fir.alloca i32 {bindc_name = "i"}
+    %1:2 = hlfir.declare %0 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFomploopEi"}
+    %3:2 = hlfir.declare %2 {uniq_name = "_QFomploopEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %4 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFomploopElocal_init_var"}
+    %5:2 = hlfir.declare %4 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %6 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFomploopElocal_var"}
+    %7:2 = hlfir.declare %6 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    %c1_i32 = arith.constant 1 : i32
+    %8 = fir.convert %c1_i32 : (i32) -> index
+    %c10_i32 = arith.constant 10 : i32
+    %9 = fir.convert %c10_i32 : (i32) -> index
+    %c1 = arith.constant 1 : index
+    fir.do_loop %arg0 = %8 to %9 step %c1 unordered private(@local_privatizer %7#0 -> %arg1, @local_init_privatizer %5#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
+      %10 = fir.convert %arg0 : (index) -> i32
+      fir.store %10 to %1#1 : !fir.ref<i32>
+      %12:2 = hlfir.declare %arg1 {uniq_name = "_QFomploopElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %14:2 = hlfir.declare %arg2 {uniq_name = "_QFomploopElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %16 = fir.load %1#0 : !fir.ref<i32>
+      %c5_i32 = arith.constant 5 : i32
+      %17 = arith.cmpi slt, %16, %c5_i32 : i32
+      fir.if %17 {
+        %c42_i32 = arith.constant 42 : i32
+        hlfir.assign %c42_i32 to %12#0 : i32, !fir.ref<i32>
+      } else {
+        %c84_i32 = arith.constant 84 : i32
+        hlfir.assign %c84_i32 to %14#0 : i32, !fir.ref<i32>
+      }
+    }
+    return
+  }
+}
diff --git a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
index 73f388cbab6c9..da14fcd25a8d3 100644
--- a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
@@ -16,8 +16,8 @@ mlir_tablegen(FIRAttr.cpp.inc -gen-attrdef-defs)
 set(LLVM_TARGET_DEFINITIONS FIROps.td)
 mlir_tablegen(FIROps.h.inc -gen-op-decls)
 mlir_tablegen(FIROps.cpp.inc -gen-op-defs)
-mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls)
-mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs)
+mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls -typedefs-dialect=fir)
+mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs -typedefs-dialect=fir)
 add_public_tablegen_target(FIROpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS FortranVariableInterface.td)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553d..34647263d6cc7 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -16,6 +16,7 @@
 
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
+include "mlir/Dialect/OpenMP/OpenMPClauses.td"
 include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -2171,7 +2172,7 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 
-  let arguments = (ins
+  defvar opArgs = (ins
     Index:$lowerBound,
     Index:$upperBound,
     Index:$step,
@@ -2182,6 +2183,8 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
     OptionalAttr<ArrayAttr>:$reduceAttrs,
     OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
   );
+
+  let arguments = !con(opArgs, OpenMP_PrivateClause.arguments);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
@@ -2193,24 +2196,38 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
       CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
       CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
       CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
-      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
+      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes,
+      CArg<"mlir::ValueRange", "std::nullopt">:$private_vars,
+      CArg<"mlir::ArrayRef<mlir::Attribute>", "{}">:$private_syms
+      )>
   ];
 
-  let extraClassDeclaration = [{
-    mlir::Value getInductionVar() { return getBody()->getArgument(0); }
+  defvar opExtraClassDeclaration = [{
     mlir::OpBuilder getBodyBuilder() {
       return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
     }
+
+    /// Region argument accessors.
+    mlir::Value getInductionVar() { return getBody()->getArgument(0); }
     mlir::Block::BlockArgListType getRegionIterArgs() {
-      return getBody()->getArguments().drop_front();
+      // 1 for skipping the induction variable.
+      return getBody()->getArguments().slice(1, getNumIterOperands());
+    }
+    mlir::Block::BlockArgListType getRegionPrivateArgs() {
+     return getBody()->getArguments().slice(1 + getNumIterOperands(),
+                                            numPrivateBlockArgs());
     }
+
+    /// Operation operand accessors.
     mlir::Operation::operand_range getIterOperands() {
       return getOperands()
-          .drop_front(getNumControlOperands() + getNumReduceOperands());
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumIterOperands());
     }
     llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
       return getOperation()->getOpOperands()
-          .drop_front(getNumControlOperands() + getNumReduceOperands());
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumIterOperands());
     }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2219,7 +2236,7 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
 
     /// Number of region arguments for loop-carried values
     unsigned getNumRegionIterArgs() {
-      return getBody()->getNumArguments() - 1;
+      return getNumIterOperands();
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
@@ -2258,6 +2275,10 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
                            unsigned resultNum);
     mlir::Value blockArgToSourceOp(unsigned blockArgNum);
   }];
+
+  let extraClassDeclaration =
+    !strconcat(opExtraClassDeclaration, "\n",
+               OpenMP_PrivateClause.extraClassDeclaration);
 }
 
 def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 7e50622db08c9..c729414cd2393 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2478,14 +2478,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
                           bool finalCountValue, mlir::ValueRange iterArgs,
                           mlir::ValueRange reduceOperands,
                           llvm::ArrayRef<mlir::Attribute> reduceAttrs,
-                          llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+                          llvm::ArrayRef<mlir::NamedAttribute> attributes,
+                          mlir::ValueRange privateVars,
+                          mlir::ArrayRef<mlir::Attribute> privateSyms) {
   result.addOperands({lb, ub, step});
   result.addOperands(reduceOperands);
   result.addOperands(iterArgs);
   result.addAttribute(getOperandSegmentSizeAttr(),
                       builder.getDenseI32ArrayAttr(
                           {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
-                           static_cast<int32_t>(iterArgs.size())}));
+                           static_cast<int32_t>(iterArgs.size()), 0}));
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
     result.addAttribute(getFinalValueAttrName(result.name),
@@ -2561,8 +2563,9 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
 
   // Parse the optional initial iteration arguments.
   llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
-  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   llvm::SmallVector<mlir::Type> argTypes;
+
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   bool prependCount = false;
   regionArgs.push_back(inductionVariable);
 
@@ -2587,15 +2590,6 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     prependCount = true;
   }
 
-  // Set the operandSegmentSizes attribute
-  result.addAttribute(getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(
-                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
-                           static_cast<int32_t>(iterOperands.size())}));
-
-  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-    return mlir::failure();
-
   // Induction variable.
   if (prependCount)
     result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
@@ -2604,15 +2598,77 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     argTypes.push_back(indexType);
   // Loop carried variables
   argTypes.append(result.types.begin(), result.types.end());
-  // Parse the body region.
-  auto *body = result.addRegion();
+
   if (regionArgs.size() != argTypes.size())
     return parser.emitError(
         parser.getNameLoc(),
         "mismatch in number of loop-carried values and defined values");
+
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
+  if (succeeded(parser.parseOptionalKeyword("private"))) {
+    std::size_t oldArgTypesSize = argTypes.size();
+    if (failed(parser.parseLParen()))
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
+            return mlir::failure();
+
+          if (parser.parseOperand(privateOperands.emplace_back()) ||
+              parser.parseArrow() ||
+              parser.parseArgument(regionArgs.emplace_back()))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (failed(parser.parseColon()))
+      return mlir::failure();
+
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseType(argTypes.emplace_back())))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (regionArgs.size() != argTypes.size())
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of private arg and types");
+
+    if (failed(parser.parseRParen()))
+      return mlir::failure();
+
+    for (auto operandType : llvm::zip_equal(
+             privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
+        return mlir::failure();
+
+    llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
+                                                   privateSymbolVec.end());
+    result.addAttribute(getPrivateSymsAttrName(result.name),
+                        builder.getArrayAttr(symbolAttrs));
+  }
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return mlir::failure();
+
+  // Set the operandSegmentSizes attribute
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterOperands.size()),
+                           static_cast<int32_t>(privateOperands.size())}));
+
   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
     regionArgs[i].type = argTypes[i];
 
+  // Parse the body region.
+  auto *body = result.addRegion();
   if (parser.parseRegion(*body, regionArgs))
     return mlir::failure();
 
@@ -2706,9 +2762,25 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " -> " << getResultTypes();
     printBlockTerminators = true;
   }
-  p.printOptionalAttrDictWithKeyword(
-      (*this)->getAttrs(),
-      {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
+
+  if (numPrivateBlockArgs() > 0) {
+    p << " private(";
+    llvm::interleaveComma(llvm::zip_equal(getPrivateSymsAttr(),
+                                          getPrivateVars(),
+                                          getRegionPrivateArgs()),
+                          p, [&](auto it) {
+                            p << std::get<0>(it) << " " << std::get<1>(it)
+                              << " -> " << std::get<2>(it);
+                          });
+    p << " : ";
+    llvm::interleaveComma(getPrivateVars(), p,
+                          [&](auto it) { p << it.getType(); });
+    p << ")";
+  }
+
+  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
+                                     {"unordered", "finalValue", "reduceAttrs",
+                                      "operandSegmentSizes", "private_syms"});
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 printBlockTerminators);
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index b09bbf6106dbb..88779e6ebd977 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -32,6 +32,19 @@ using namespace fir;
 using namespace mlir;
 
 namespace {
+/// Looks up from the operation from and returns the PrivateClauseOp with
+/// name symbolName
+///
+/// TODO Copied from OpenMPToLLVMIRTranslation.cpp, move to a shared location.
+/// Maybe a static function on the `PrivateClauseOp`.
+static omp::PrivateClauseOp findPrivatizer(Operation *from,
+                                           SymbolRefAttr symbolName) {
+  omp::PrivateClauseOp privatizer =
+      SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
+                                                                 symbolName);
+  assert(privatizer && "privatizer not found in the symbol table");
+  return privatizer;
+}
 
 // Conversion of fir control ops to more primitive control-flow.
 //
@@ -57,6 +70,50 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
     auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
         rewriter.getContext(), flags);
 
+    // Handle privatization
+    if (!loop.getPrivateVars().empty()) {
+      mlir::OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&loop.getRegion().front());
+      std::optional<ArrayAttr> privateSyms = loop.getPrivateSyms();
+
+      for (auto [privateVar, privateArg, privatizerSym] :
+           llvm::zip_equal(loop.getPrivateVars(), loop.getRegionPrivateArgs(),
+                           *privateSyms)) {
+        SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privatizerSym);
+        omp::PrivateClauseOp privatizer = findPrivatizer(loop, privatizerName);
+
+        mlir::Value localAlloc =
+            rewriter.create<fir::AllocaOp>(loop.getLoc(), privatizer.getType());
+
+        if (privatizer.getDataSharingType() ==
+            omp::DataSharingClauseType::FirstPrivate) {
+          mlir::Block *beforeLocalInit = rewriter.getInsertionBlock();
+          mlir::Block *afterLocalInit = rewriter.splitBlock(
+              rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+          rewriter.cloneRegionBefore(privatizer.getCopyRegion(),
+                                     afterLocalInit);
+          mlir::Block* copyRegionFront = beforeLocalInit->getNextNode();
+          mlir::Block* copyRegionBack = afterLocalInit->getPrevNode();
+
+          rewriter.setInsertionPoint(beforeLocalInit, beforeLocalInit->end());
+          rewriter.create<mlir::cf::BranchOp>(
+              loc, copyRegionFront,
+              llvm::SmallVector<mlir::Value>{privateVar, privateArg});
+
+          rewriter.eraseOp(copyRegionBack->getTerminator());
+          rewriter.setInsertionPoint(copyRegionBack, copyRegionBack->end());
+          rewriter.create<mlir::cf::BranchOp>(loc, afterLocalInit);
+        }
+
+        rewriter.replaceAllUsesWith(privateArg, localAlloc);
+      }
+
+      loop.getRegion().front().eraseArguments(1 + loop.getNumRegionIterArgs(),
+                                              loop.numPrivateBlockArgs());
+      loop.getPrivateVarsMutable().clear();
+      loop.setPrivateSymsAttr(nullptr);
+    }
+
     // Create the start and end blocks that will wrap the DoLoopOp with an
     // initalizer and an end point
     auto *initBlock = rewriter.getInsertionBlock();

Copy link

github-actions bot commented Feb 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ergawy ergawy force-pushed the users/ergawy/locality_specs_4_lower_local_init_spec branch 2 times, most recently from 7dc7e07 to 0005b5e Compare February 21, 2025 10:59
@ergawy ergawy force-pushed the users/ergawy/locality_specs_4_lower_local_init_spec branch from 0005b5e to a7e151b Compare March 12, 2025 05:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants