Skip to content

[WIP] Implement workdistribute construct #140523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

skc7
Copy link
Contributor

@skc7 skc7 commented May 19, 2025

Note: This is a very early work-in-progress PR implementing workdistribute in flang. More changes/commits are incoming.

@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from 6e8010d to df65bd5 Compare May 19, 2025 14:18
@skc7 skc7 requested a review from mjklemm May 20, 2025 07:12
@skc7 skc7 marked this pull request as draft May 20, 2025 07:16
@skc7 skc7 self-assigned this May 20, 2025
Copy link

github-actions bot commented Jun 5, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp flang/include/flang/Semantics/openmp-directive-sets.h flang/lib/Lower/OpenMP/OpenMP.cpp flang/lib/Optimizer/Passes/Pipelines.cpp flang/lib/Parser/openmp-parsers.cpp flang/lib/Semantics/resolve-directives.cpp
View the diff from clang-format here.
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 8f2de92cf..30ad180f9 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
 #include <mlir/Dialect/Utils/IndexingUtils.h>
@@ -34,7 +35,6 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include <optional>
 #include <variant>
 
@@ -357,11 +357,13 @@ struct SplitTargetResult {
 /// original data region and avoid unnecessary data movement at each of the
 /// subkernels - we split the target region into a target_data{target}
 /// nest where only the outer one moves the data
-std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp, RewriterBase &rewriter) {
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
 
   auto loc = targetOp->getLoc();
   if (targetOp.getMapVars().empty()) {
-    LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " target region has no data maps\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
     return std::nullopt;
   }
 
@@ -381,10 +383,9 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp, Rewrite
   auto ifExpr = targetOp.getIfExpr();
   auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
   auto devicePtrVars = targetOp.getIsDevicePtrVars();
-  auto targetDataOp = rewriter.create<omp::TargetDataOp>(loc, device, ifExpr, 
-                                                          mlir::ValueRange{byRefMapInfos},
-                                                          deviceAddrVars,
-                                                          devicePtrVars);
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, mlir::ValueRange{byRefMapInfos}, deviceAddrVars,
+      devicePtrVars);
 
   auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
   rewriter.create<mlir::omp::TerminatorOp>(loc);
@@ -400,10 +401,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp, Rewrite
 
   // Erase TargetOp and its MapInfoOps
   rewriter.eraseOp(targetOp);
-  
+
   for (auto mapInfo : MapInfos) {
     auto mapInfoRes = mapInfo.getResult();
-    if (mapInfoRes.getUsers().empty()) 
+    if (mapInfoRes.getUsers().empty())
       rewriter.eraseOp(mapInfo);
   }
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
@@ -441,9 +442,9 @@ static Type getPtrTypeForOmp(Type ty) {
     return fir::LLVMPointerType::get(ty);
 }
 
-static TempOmpVar 
-allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) {
-  MLIRContext& ctx = *ty.getContext();
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
   Value alloc;
   Type allocType;
   auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
@@ -453,28 +454,30 @@ allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) {
     allocType = llvmPtrTy;
     alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
     allocType = intTy;
-  }
-  else {
+  } else {
     allocType = ty;
     alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
   }
   auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
     return rewriter.create<omp::MapInfoOp>(
-      loc, alloc.getType(), alloc,
-      TypeAttr::get(allocType),
-      rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), mappingFlags),
-      rewriter.getAttr<omp::VariableCaptureKindAttr>(
-          omp::VariableCaptureKind::ByRef),
-      /*varPtrPtr=*/Value{},
-      /*members=*/SmallVector<Value>{},
-      /*member_index=*/mlir::ArrayAttr{},
-      /*bounds=*/ValueRange(),
-      /*mapperId=*/mlir::FlatSymbolRefAttr(), 
-      /*name=*/rewriter.getStringAttr(name),
-      rewriter.getBoolAttr(false));
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
   };
-  uint64_t mapFrom = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
-  uint64_t mapTo = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
   auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
   auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
   return TempOmpVar{mapInfoFrom, mapInfoTo};
@@ -496,23 +499,25 @@ static bool usedOutsideSplit(Value v, Operation *split) {
 };
 
 static bool isOpToBeCached(Operation *op) {
-  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {  
-    Value memref = loadOp.getMemref();  
-    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {  
+  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {
+    Value memref = loadOp.getMemref();
+    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {
       // 'op' is an operation within the targetOp that 'splitBefore' is also in.
-      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();  
-      // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.  
-      // This implies the load is from a variable directly mapped into the target region.  
-      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&  
-          !parentOpOfLoadBlock->getRegions().empty()) {  
-        Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front();  
-        if (blockArg.getOwner() == targetOpEntryBlock) {  
-          // This load is from a direct argument of the target op.  
+      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();
+      // Ensure the blockArg belongs to the entry block of this parent
+      // omp.TargetOp. This implies the load is from a variable directly mapped
+      // into the target region.
+      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&
+          !parentOpOfLoadBlock->getRegions().empty()) {
+        Block *targetOpEntryBlock =
+            &parentOpOfLoadBlock->getRegions().front().front();
+        if (blockArg.getOwner() == targetOpEntryBlock) {
+          // This load is from a direct argument of the target op.
           // It's safe to recompute.
-          return false;  
-        }  
-      }  
-    }  
+          return false;
+        }
+      }
+    }
   }
   return true;
 }
@@ -521,24 +526,26 @@ static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
   if (isa<fir::DeclareOp>(op))
     return true;
 
-  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {  
-    Value memref = loadOp.getMemref();  
-    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {  
+  if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {
+    Value memref = loadOp.getMemref();
+    if (auto blockArg = dyn_cast<BlockArgument>(memref)) {
       // 'op' is an operation within the targetOp that 'splitBefore' is also in.
-      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();  
-      // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.  
-      // This implies the load is from a variable directly mapped into the target region.  
-      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&  
-          !parentOpOfLoadBlock->getRegions().empty()) {  
-        Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front();  
-        if (blockArg.getOwner() == targetOpEntryBlock) {  
-          // This load is from a direct argument of the target op.  
+      Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();
+      // Ensure the blockArg belongs to the entry block of this parent
+      // omp.TargetOp. This implies the load is from a variable directly mapped
+      // into the target region.
+      if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&
+          !parentOpOfLoadBlock->getRegions().empty()) {
+        Block *targetOpEntryBlock =
+            &parentOpOfLoadBlock->getRegions().front().front();
+        if (blockArg.getOwner() == targetOpEntryBlock) {
+          // This load is from a direct argument of the target op.
           // It's safe to recompute.
-          return true;  
-        }  
-      }  
-    }  
-  } 
+          return true;
+        }
+      }
+    }
+  }
 
   llvm::SmallVector<MemoryEffects::EffectInstance> effects;
   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
@@ -557,11 +564,10 @@ struct SplitResult {
   omp::TargetOp postTargetOp;
 };
 
-static void collectNonRecomputableDeps(Value& v,
-                                omp::TargetOp targetOp,
-                                SetVector<Operation *>& nonRecomputable,
-                                SetVector<Operation *>& toCache,
-                                SetVector<Operation *>& toRecompute) {
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
   Operation *op = v.getDefiningOp();
   if (!op) {
     assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
@@ -573,16 +579,16 @@ static void collectNonRecomputableDeps(Value& v,
   }
   toRecompute.insert(op);
   for (auto opr : op->getOperands())
-    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, toRecompute);
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+                               toRecompute);
 }
 
-
 static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
-                        MLIRContext& ctx,
-                        IRMapping &mapping, Operation *splitBefore,
-                        Block *targetBlock, Block *newTargetBlock,
-                        SmallVector<Value>& allocs,
-                        SetVector<Operation *>& toRecompute) {
+                                    MLIRContext &ctx, IRMapping &mapping,
+                                    Operation *splitBefore, Block *targetBlock,
+                                    Block *newTargetBlock,
+                                    SmallVector<Value> &allocs,
+                                    SetVector<Operation *> &toRecompute) {
   for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
     auto originalArg = targetBlock->getArgument(i);
     auto newArg = newTargetBlock->addArgument(originalArg.getType(),
@@ -592,16 +598,17 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
   auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
   for (auto original : allocs) {
     Value newArg = newTargetBlock->addArgument(
-      getPtrTypeForOmp(original.getType()), original.getLoc());
+        getPtrTypeForOmp(original.getType()), original.getLoc());
     Value restored;
     if (isPtr(original.getType())) {
       restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
       if (!isa<LLVM::LLVMPointerType>(original.getType()))
-        restored = rewriter.create<UnrealizedConversionCastOp>(loc, original.getType(), ValueRange(restored))
-                           .getResult(0);
-    } 
-    else {
-        restored = rewriter.create<fir::LoadOp>(loc, newArg);
+        restored = rewriter
+                       .create<UnrealizedConversionCastOp>(
+                           loc, original.getType(), ValueRange(restored))
+                       .getResult(0);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
     }
     mapping.map(original, restored);
   }
@@ -612,14 +619,14 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
 }
 
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
-                              RewriterBase &rewriter) {
+                             RewriterBase &rewriter) {
   auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
-  MLIRContext& ctx = *targetOp.getContext();
+  MLIRContext &ctx = *targetOp.getContext();
   assert(targetOp);
   auto loc = targetOp.getLoc();
   auto *targetBlock = &targetOp.getRegion().front();
   rewriter.setInsertionPoint(targetOp);
-   
+
   auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
   auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
 
@@ -629,21 +636,24 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
   SetVector<Operation *> nonRecomputable;
   SmallVector<Value> allocs;
 
-  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) {
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
     for (auto res : it->getResults()) {
       if (usedOutsideSplit(res, splitBeforeOp))
         requiredVals.push_back(res);
     }
     if (!isRecomputableAfterFission(&*it, splitBeforeOp))
-        nonRecomputable.insert(&*it);
+      nonRecomputable.insert(&*it);
   }
 
   for (auto requiredVal : requiredVals)
-    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, toRecompute);
-  
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+
   for (Operation *op : toCache) {
     for (auto res : op->getResults()) {
-      auto alloc = allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
       allocs.push_back(res);
       preMapOperands.push_back(alloc.from);
       postMapOperands.push_back(alloc.to);
@@ -653,16 +663,16 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
   rewriter.setInsertionPoint(targetOp);
 
   auto preTargetOp = rewriter.create<omp::TargetOp>(
-        targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
-        targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
-        targetOp.getDependVars(), targetOp.getDevice(),
-        targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
-        targetOp.getIfExpr(), targetOp.getInReductionVars(),
-        targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
-        targetOp.getIsDevicePtrVars(), preMapOperands,
-        targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
-        targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
-        targetOp.getPrivateMapsAttr()); 
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
   auto *preTargetBlock = rewriter.createBlock(
       &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
   IRMapping preMapping;
@@ -677,15 +687,15 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
 
   auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
 
-
   for (auto original : allocs) {
     Value toStore = preMapping.lookup(original);
     auto newArg = preTargetBlock->addArgument(
         getPtrTypeForOmp(original.getType()), original.getLoc());
     if (isPtr(original.getType())) {
       if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
-        toStore = rewriter.create<UnrealizedConversionCastOp>(loc, llvmPtrTy,
-                                                           ValueRange(toStore))
+        toStore = rewriter
+                      .create<UnrealizedConversionCastOp>(loc, llvmPtrTy,
+                                                          ValueRange(toStore))
                       .getResult(0);
       rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
     } else {
@@ -697,53 +707,52 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
   rewriter.setInsertionPoint(targetOp);
 
   auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
-      targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
-      targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
-      targetOp.getDependVars(), targetOp.getDevice(),
-      targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
-      targetOp.getIfExpr(), targetOp.getInReductionVars(),
-      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
-      targetOp.getIsDevicePtrVars(), postMapOperands,
-      targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
       targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
-      targetOp.getPrivateMapsAttr()); 
+      targetOp.getPrivateMapsAttr());
 
   auto *isolatedTargetBlock =
-        rewriter.createBlock(&isolatedTargetOp.getRegion(),
-                             isolatedTargetOp.getRegion().begin(), {}, {});
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
 
   IRMapping isolatedMapping;
   reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
-                          targetBlock, isolatedTargetBlock,
-                          allocs, toRecompute);
+                          targetBlock, isolatedTargetBlock, allocs,
+                          toRecompute);
   rewriter.clone(*splitBeforeOp, isolatedMapping);
   rewriter.create<omp::TerminatorOp>(loc);
 
   omp::TargetOp postTargetOp = nullptr;
-  
+
   if (splitAfter) {
-      rewriter.setInsertionPoint(targetOp);
+    rewriter.setInsertionPoint(targetOp);
     postTargetOp = rewriter.create<omp::TargetOp>(
-        targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
-        targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
-        targetOp.getDependVars(), targetOp.getDevice(),
-        targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
-        targetOp.getIfExpr(), targetOp.getInReductionVars(),
-        targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
-        targetOp.getIsDevicePtrVars(), postMapOperands,
-        targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getLoc(), targetOp.getAllocateVars(),
+        targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+        targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+        targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+        targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+        targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+        targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+        postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
         targetOp.getPrivateSymsAttr(), targetOp.getThreadLimit(),
-        targetOp.getPrivateMapsAttr()); 
+        targetOp.getPrivateMapsAttr());
     auto *postTargetBlock = rewriter.createBlock(
-          &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+        &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
     IRMapping postMapping;
-    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp, 
-                            targetBlock, postTargetBlock,
-                            allocs, toRecompute);
+    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+                            targetBlock, postTargetBlock, allocs, toRecompute);
 
     assert(splitBeforeOp->getNumResults() == 0 ||
-             llvm::all_of(splitBeforeOp->getResults(),
-                          [](Value result) { return result.use_empty(); }));
+           llvm::all_of(splitBeforeOp->getResults(),
+                        [](Value result) { return result.use_empty(); }));
 
     for (auto it = std::next(splitBeforeOp->getIterator());
          it != targetBlock->end(); it++)
@@ -780,13 +789,13 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       runtimeCall = cast<fir::CallOp>(op);
 
     if (allocOp || freeOp || runtimeCall)
-        continue;
+      continue;
     opsToMove.push_back(op);
   }
   // Move ops before targetOp and erase from region
   for (Operation *op : opsToMove)
     rewriter.clone(*op, mapping);
-  
+
   rewriter.eraseOp(targetOp);
 }
 
@@ -794,7 +803,7 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
   auto tuple = getNestedOpToIsolate(targetOp);
   if (!tuple) {
     LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
-    //moveToHost(targetOp, rewriter);
+    // moveToHost(targetOp, rewriter);
     return;
   }
 
@@ -804,13 +813,13 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
 
   if (splitBefore && splitAfter) {
     auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    //moveToHost(res.preTargetOp, rewriter);
+    // moveToHost(res.preTargetOp, rewriter);
     fissionTarget(res.postTargetOp, rewriter);
     return;
   }
   if (splitBefore) {
     auto res = isolateOp(toIsolate, splitAfter, rewriter);
-    //moveToHost(res.preTargetOp, rewriter);
+    // moveToHost(res.preTargetOp, rewriter);
     return;
   }
   if (splitAfter) {
@@ -853,10 +862,10 @@ public:
       IRRewriter rewriter(&context);
       for (auto targetOp : targetOps) {
         auto res = splitTargetData(targetOp, rewriter);
-        if (res) fissionTarget(res->targetOp, rewriter);
+        if (res)
+          fissionTarget(res->targetOp, rewriter);
       }
     }
-
   }
 };
 } // namespace

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants