diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index dd21afec6eb453..11649ef2d03370 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -119,6 +119,24 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]> CArg<"TypeAttr">:$type)> ]; + let extraClassDeclaration = [{ + BlockArgument getAllocMoldArg() { + return getAllocRegion().getArgument(0); + } + BlockArgument getCopyMoldArg() { + auto ®ion = getCopyRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } + BlockArgument getCopyPrivateArg() { + auto ®ion = getCopyRegion(); + return region.empty() ? nullptr : region.getArgument(1); + } + BlockArgument getDeallocMoldArg() { + auto ®ion = getDeallocRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } + }]; + let hasVerifier = 1; } @@ -1601,22 +1619,41 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, "( `cleanup` $cleanupRegion^ )? "; let extraClassDeclaration = [{ + BlockArgument getAllocMoldArg() { + auto ®ion = getAllocRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } + BlockArgument getInitializerMoldArg() { + return getInitializerRegion().getArgument(0); + } + BlockArgument getInitializerAllocArg() { + return getAllocRegion().empty() ? + nullptr : getInitializerRegion().getArgument(1); + } + BlockArgument getReductionLhsArg() { + return getReductionRegion().getArgument(0); + } + BlockArgument getReductionRhsArg() { + return getReductionRegion().getArgument(1); + } + BlockArgument getAtomicReductionLhsArg() { + auto ®ion = getAtomicReductionRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } + BlockArgument getAtomicReductionRhsArg() { + auto ®ion = getAtomicReductionRegion(); + return region.empty() ? nullptr : region.getArgument(1); + } + BlockArgument getCleanupAllocArg() { + auto ®ion = getCleanupRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } + PointerLikeType getAccumulatorType() { if (getAtomicReductionRegion().empty()) return {}; - return cast(getAtomicReductionRegion().front().getArgument(0).getType()); - } - - Value getInitializerMoldArg() { - return getInitializerRegion().front().getArgument(0); - } - - Value getInitializerAllocArg() { - if (getAllocRegion().empty() || - getInitializerRegion().front().getNumArguments() != 2) - return {nullptr}; - return getInitializerRegion().front().getArgument(1); + return cast(getAtomicReductionLhsArg().getType()); } }]; let hasRegionVerifier = 1; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 19d80fbbd699b0..816c7ff9509d27 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -480,12 +480,11 @@ makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Value *lhs, llvm::Value *rhs, llvm::Value *&result) mutable { - Region &reductionRegion = decl.getReductionRegion(); - moduleTranslation.mapValue(reductionRegion.front().getArgument(0), lhs); - moduleTranslation.mapValue(reductionRegion.front().getArgument(1), rhs); + moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs); + moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs); builder.restoreIP(insertPoint); SmallVector phis; - if (failed(inlineConvertOmpRegions(reductionRegion, + if (failed(inlineConvertOmpRegions(decl.getReductionRegion(), "omp.reduction.nonatomic.body", builder, moduleTranslation, &phis))) return llvm::OpenMPIRBuilder::InsertPointTy(); @@ -513,12 +512,11 @@ makeAtomicReductionGen(omp::DeclareReductionOp decl, OwningAtomicReductionGen atomicGen = [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *, llvm::Value *lhs, llvm::Value *rhs) mutable { - Region &atomicRegion = decl.getAtomicReductionRegion(); - moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs); - moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs); + moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs); + moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs); builder.restoreIP(insertPoint); SmallVector phis; - if (failed(inlineConvertOmpRegions(atomicRegion, + if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(), "omp.reduction.atomic.body", builder, moduleTranslation, &phis))) return llvm::OpenMPIRBuilder::InsertPointTy(); @@ -1674,9 +1672,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // argument of the `alloc` region and the second argument of the `copy` // region to be the yielded value of the `alloc` region (this is the // private clone of the privatized value). - copyCloneBuilder.mergeBlocks( - &*newCopyRegionFrontBlock, &*oldAllocBackBlock, - {allocRegion.getArgument(0), oldAllocYieldOp.getOperand(0)}); + copyCloneBuilder.mergeBlocks(&*newCopyRegionFrontBlock, + &*oldAllocBackBlock, + {mlirPrivatizerClone.getAllocMoldArg(), + oldAllocYieldOp.getOperand(0)}); // 4. The old terminator of the `alloc` region is not needed anymore, so // delete it. @@ -1686,8 +1685,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Replace the privatizer block argument with mlir value being privatized. // This way, the body of the privatizer will be changed from using the // region/block argument to the value being privatized. - auto allocRegionArg = allocRegion.getArgument(0); - replaceAllUsesInRegionWith(allocRegionArg, mlirPrivVar, allocRegion); + replaceAllUsesInRegionWith(mlirPrivatizerClone.getAllocMoldArg(), + mlirPrivVar, allocRegion); auto oldIP = builder.saveIP(); builder.restoreIP(allocaIP); @@ -3480,10 +3479,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, " private allocatables is not supported yet"); bodyGenStatus = failure(); } else { - Region &allocRegion = privatizer.getAllocRegion(); - BlockArgument allocRegionArg = allocRegion.getArgument(0); - moduleTranslation.mapValue(allocRegionArg, + moduleTranslation.mapValue(privatizer.getAllocMoldArg(), moduleTranslation.lookupValue(privVar)); + Region &allocRegion = privatizer.getAllocRegion(); SmallVector yieldedValues; if (failed(inlineConvertOmpRegions( allocRegion, "omp.targetop.privatizer", builder,