diff --git a/flang/include/flang/Lower/OpenMP/Utils.h b/flang/include/flang/Lower/OpenMP/Utils.h index 800683a464b1a1..39d7c9ea79c812 100644 --- a/flang/include/flang/Lower/OpenMP/Utils.h +++ b/flang/include/flang/Lower/OpenMP/Utils.h @@ -153,9 +153,7 @@ void insertChildMapInfoIntoParent( Fortran::lower::StatementContext &stmtCtx, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymbols); + llvm::SmallVectorImpl &mapSymbols); mlir::Type getLoopVarType(lower::AbstractConverter &converter, std::size_t loopVarTypeSize); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 148aa68f74feaf..e36295a3804170 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter, static void addUseDeviceClause( lower::AbstractConverter &converter, const omp::ObjectList &objects, llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) { genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) { + for (mlir::Value &operand : operands) checkMapType(operand.getLoc(), operand.getType()); - useDeviceTypes.push_back(operand.getType()); - useDeviceLocs.push_back(operand.getLoc()); - } + for (const omp::Object &object : objects) useDeviceSyms.push_back(object.sym()); } @@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { bool ClauseProcessor::processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const { + llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::HasDeviceAddr &devAddrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -864,14 +858,12 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const { + llvm::SmallVectorImpl &isDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::IsDevicePtr &devPtrClause, const parser::CharBlock &) { addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceTypes, isDeviceLocs, isDeviceSymbols); + isDeviceSyms); }); } @@ -891,9 +883,7 @@ void ClauseProcessor::processMapObjects( llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymTypes) const { + llvm::SmallVectorImpl &mapSyms) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); for (const omp::Object &object : objects) { @@ -948,11 +938,7 @@ void ClauseProcessor::processMapObjects( object, parentMemberIndices[parentObj.value()], mapOp, semaCtx); } else { mapVars.push_back(mapOp); - mapSyms->push_back(object.sym()); - if (mapSymTypes) - mapSymTypes->push_back(baseOp.getType()); - if (mapSymLocs) - mapSymLocs->push_back(baseOp.getLoc()); + 
mapSyms.push_back(object.sym()); } } } @@ -960,9 +946,7 @@ void ClauseProcessor::processMapObjects( bool ClauseProcessor::processMap( mlir::Location currentLocation, lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymTypes) const { + llvm::SmallVectorImpl *mapSyms) const { // We always require tracking of symbols, even if the caller does not, // so we create an optionally used local set of symbols when the mapSyms // argument is not present. @@ -1018,13 +1002,11 @@ bool ClauseProcessor::processMap( processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, - parentMemberIndices, result.mapVars, ptrMapSyms, - mapSymLocs, mapSymTypes); + parentMemberIndices, result.mapVars, *ptrMapSyms); }); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, - result.mapVars, mapSymTypes, mapSymLocs, - ptrMapSyms); + result.mapVars, *ptrMapSyms); return clauseFound; } @@ -1044,7 +1026,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, parentMemberIndices, result.mapVars, - &mapSymbols); + mapSymbols); }; bool clauseFound = findRepeatableClause(callbackFn); @@ -1053,7 +1035,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, insertChildMapInfoIntoParent( converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars, - /*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols); + mapSymbols); return clauseFound; } @@ -1071,8 +1053,7 @@ bool ClauseProcessor::processNontemporal( bool ClauseProcessor::processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl *outReductionTypes, - llvm::SmallVectorImpl *outReductionSyms) const { + llvm::SmallVectorImpl &outReductionSyms) const { return findRepeatableClause( [&](const omp::clause::Reduction &clause, const parser::CharBlock &) { llvm::SmallVector reductionVars; @@ -1080,25 +1061,16 @@ bool ClauseProcessor::processReduction( llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction( - currentLocation, converter, clause, reductionVars, reduceVarByRef, - reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr); + rp.addDeclareReduction(currentLocation, converter, clause, + reductionVars, reduceVarByRef, + reductionDeclSymbols, reductionSyms); // Copy local lists into the output. 
llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); llvm::copy(reductionDeclSymbols, std::back_inserter(result.reductionSyms)); - - if (outReductionTypes) { - outReductionTypes->reserve(outReductionTypes->size() + - reductionVars.size()); - llvm::transform(reductionVars, std::back_inserter(*outReductionTypes), - [](mlir::Value v) { return v.getType(); }); - } - - if (outReductionSyms) - llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms)); + llvm::copy(reductionSyms, std::back_inserter(outReductionSyms)); }); } @@ -1124,8 +1096,6 @@ bool ClauseProcessor::processEnter( bool ClauseProcessor::processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const { std::map parentMemberIndices; bool clauseFound = findRepeatableClause( @@ -1137,19 +1107,16 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, - result.useDeviceAddrVars, &useDeviceTypes, - &useDeviceLocs, &useDeviceSyms); + result.useDeviceAddrVars, useDeviceSyms); return clauseFound; } bool ClauseProcessor::processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const { std::map parentMemberIndices; @@ -1162,12 +1129,11 @@ bool ClauseProcessor::processUseDevicePtr( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, - &useDeviceSyms, &useDeviceLocs, &useDeviceTypes); + useDeviceSyms); }); insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, - result.useDevicePtrVars, &useDeviceTypes, - &useDeviceLocs, &useDeviceSyms); + result.useDevicePtrVars, useDeviceSyms); return clauseFound; } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index aeace91084c44c..b54c7796499ec5 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -68,9 +68,7 @@ class ClauseProcessor { mlir::omp::FinalClauseOps &result) const; bool processHasDeviceAddr( mlir::omp::HasDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const; + llvm::SmallVectorImpl &isDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; @@ -104,43 +102,33 @@ class ClauseProcessor { mlir::omp::IfClauseOps &result) const; bool processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, - llvm::SmallVectorImpl &isDeviceTypes, - llvm::SmallVectorImpl &isDeviceLocs, - llvm::SmallVectorImpl &isDeviceSymbols) const; + llvm::SmallVectorImpl &isDeviceSyms) const; bool processLink(llvm::SmallVectorImpl &result) const; // This method is used to process a map clause. 
- // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to - // store the original type, location and Fortran symbol for the map operands. - // They may be used later on to create the block_arguments for some of the - // target directives that require it. - bool processMap( - mlir::Location currentLocation, lower::StatementContext &stmtCtx, - mlir::omp::MapClauseOps &result, - llvm::SmallVectorImpl *mapSyms = nullptr, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + // The optional parameter mapSyms is used to store the original Fortran symbols + // for the map operands. It may be used later on to create the block_arguments + // for some of the directives that require it. + bool processMap(mlir::Location currentLocation, + lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms = + nullptr) const; bool processMotionClauses(lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result); bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const; bool processReduction( mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, - llvm::SmallVectorImpl *reductionTypes = nullptr, - llvm::SmallVectorImpl *reductionSyms = - nullptr) const; + llvm::SmallVectorImpl &reductionSyms) const; bool processTo(llvm::SmallVectorImpl &result) const; bool processUseDeviceAddr( lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const; bool processUseDevicePtr( lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) const; // Call this method for these clauses that should be supported but are not @@ -180,9 +168,7 @@ class ClauseProcessor { llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl *mapSyms, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + llvm::SmallVectorImpl &mapSyms) const; lower::AbstractConverter &converter; semantics::SemanticsContext &semaCtx; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index f086c9dbd09656..062aa391dff0c3 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -46,6 +46,40 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// +namespace { +/// Structure holding the information needed to create and bind entry block +/// arguments associated with a single clause. +struct EntryBlockArgsEntry { + llvm::ArrayRef syms; + llvm::ArrayRef vars; + + bool isValid() const { + // This check allows specifying a smaller number of symbols than values + // because in some cases a single symbol generates multiple block + // arguments. + return syms.size() <= vars.size(); + } +}; + +/// Structure holding the information needed to create and bind entry block +/// arguments associated with all clauses that can define them.
+struct EntryBlockArgs { + EntryBlockArgsEntry inReduction; + EntryBlockArgsEntry map; + EntryBlockArgsEntry priv; + EntryBlockArgsEntry reduction; + EntryBlockArgsEntry taskReduction; + EntryBlockArgsEntry useDeviceAddr; + EntryBlockArgsEntry useDevicePtr; + + bool isValid() const { + return inReduction.isValid() && map.isValid() && priv.isValid() && + reduction.isValid() && taskReduction.isValid() && + useDeviceAddr.isValid() && useDevicePtr.isValid(); + } +}; +} // namespace + /// Get the directive enumeration value corresponding to the given OpenMP /// construct PFT node. llvm::omp::Directive @@ -221,6 +255,164 @@ static void genOMPDispatch(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item); +/// Bind symbols to their corresponding entry block arguments. +/// +/// The binding will be performed inside of the current block, which does not +/// necessarily have to be part of the operation for which the binding is done. +/// However, block arguments must be accessible. This enables controlling the +/// insertion point of any new MLIR operations related to the binding of +/// arguments of a loop wrapper operation. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] op - owner operation of the block arguments to bind. +/// \param [in] args - entry block arguments information for the given +/// operation. +static void bindEntryBlockArgs(lower::AbstractConverter &converter, + mlir::omp::BlockArgOpenMPOpInterface op, + const EntryBlockArgs &args) { + assert(op != nullptr && "invalid block argument-defining operation"); + assert(args.isValid() && "invalid args"); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + auto bindSingleMapLike = [&converter, + &firOpBuilder](const semantics::Symbol &sym, + const mlir::BlockArgument &arg) { + // Clones the `bounds` placing them inside the entry block and returns + // them. 
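A minimal standalone sketch of the invariant the new EntryBlockArgsEntry/EntryBlockArgs structs encode, using placeholder Sym and Val types in place of semantics::Symbol and mlir::Value (the real structs hold llvm::ArrayRefs into lists owned by the caller):

#include <cassert>
#include <string>
#include <vector>

struct Sym { std::string name; };  // stand-in for Fortran::semantics::Symbol
struct Val { int id; };            // stand-in for mlir::Value

struct ArgsEntry {
  std::vector<const Sym *> syms;
  std::vector<Val> vars;
  // A single symbol may expand to more than one block argument, so the symbol
  // list is allowed to be shorter than the value list, never longer.
  bool isValid() const { return syms.size() <= vars.size(); }
};

struct Args {
  ArgsEntry map, priv, reduction;
  bool isValid() const {
    return map.isValid() && priv.isValid() && reduction.isValid();
  }
};

int main() {
  Sym a{"a"};
  Args args;
  args.map.syms = {&a};
  args.map.vars = {{0}, {1}};  // e.g. one extra implicitly-added argument
  assert(args.isValid());
  return 0;
}

Because the entries are per clause, a directive only fills in the clauses it actually supports and leaves the rest empty, which keeps the validity check cheap.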
+ auto cloneBound = [&](mlir::Value bound) { + if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { + mlir::Operation *clonedOp = firOpBuilder.clone(*bound.getDefiningOp()); + return clonedOp->getResult(0); + } + TODO(converter.getCurrentLocation(), + "target map-like clause operand unsupported bound type"); + }; + + auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { + llvm::SmallVector clonedBounds; + llvm::transform(bounds, std::back_inserter(clonedBounds), + [&](mlir::Value bound) { return cloneBound(bound); }); + return clonedBounds; + }; + + fir::ExtendedValue extVal = converter.getSymbolExtendedValue(sym); + auto refType = mlir::dyn_cast(arg.getType()); + if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { + converter.bindSymbol(sym, arg); + } else { + extVal.match( + [&](const fir::BoxValue &v) { + converter.bindSymbol(sym, + fir::BoxValue(arg, cloneBounds(v.getLBounds()), + v.getExplicitParameters(), + v.getExplicitExtents())); + }, + [&](const fir::MutableBoxValue &v) { + converter.bindSymbol( + sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), + v.getMutableProperties())); + }, + [&](const fir::ArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()), + v.getSourceBox())); + }, + [&](const fir::CharArrayBoxValue &v) { + converter.bindSymbol( + sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), + cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()))); + }, + [&](const fir::CharBoxValue &v) { + converter.bindSymbol( + sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); + }, + [&](const fir::UnboxedValue &v) { converter.bindSymbol(sym, arg); }, + [&](const auto &) { + TODO(converter.getCurrentLocation(), + "target map clause operand unsupported type"); + }); + } + }; + + auto bindMapLike = + [&bindSingleMapLike](llvm::ArrayRef syms, + llvm::ArrayRef args) { + // Structure component symbols don't have bindings, and can only be + // explicitly mapped individually. If a member is captured implicitly + // we map the entirety of the derived type when we find its symbol. + llvm::SmallVector processedSyms; + llvm::copy_if(syms, std::back_inserter(processedSyms), + [](auto *sym) { return !sym->owner().IsDerivedType(); }); + + for (auto [sym, arg] : llvm::zip_equal(processedSyms, args)) + bindSingleMapLike(*sym, arg); + }; + + auto bindPrivateLike = [&converter, &firOpBuilder]( + llvm::ArrayRef syms, + llvm::ArrayRef vars, + llvm::ArrayRef args) { + llvm::SmallVector processedSyms; + for (auto *sym : syms) { + if (const auto *commonDet = + sym->detailsIf()) { + llvm::transform(commonDet->objects(), std::back_inserter(processedSyms), + [&](const auto &mem) { return &*mem; }); + } else { + processedSyms.push_back(sym); + } + } + + for (auto [sym, var, arg] : llvm::zip_equal(processedSyms, vars, args)) + converter.bindSymbol( + *sym, + hlfir::translateToExtendedValue( + var.getLoc(), firOpBuilder, hlfir::Entity{arg}, + /*contiguousHint=*/ + evaluate::IsSimplyContiguous(*sym, converter.getFoldingContext())) + .first); + }; + + // Process in clause name alphabetical order to match block arguments order. 
+ bindPrivateLike(args.inReduction.syms, args.inReduction.vars, + op.getInReductionBlockArgs()); + bindMapLike(args.map.syms, op.getMapBlockArgs()); + bindPrivateLike(args.priv.syms, args.priv.vars, op.getPrivateBlockArgs()); + bindPrivateLike(args.reduction.syms, args.reduction.vars, + op.getReductionBlockArgs()); + bindPrivateLike(args.taskReduction.syms, args.taskReduction.vars, + op.getTaskReductionBlockArgs()); + bindMapLike(args.useDeviceAddr.syms, op.getUseDeviceAddrBlockArgs()); + bindMapLike(args.useDevicePtr.syms, op.getUseDevicePtrBlockArgs()); +} + +/// Get the list of base values that the specified map-like variables point to. +/// +/// This function must be kept in sync with changes to the `createMapInfoOp` +/// utility function, since it must take into account the potential introduction +/// of levels of indirection (i.e. intermediate ops). +/// +/// \param [in] vars - list of values passed to map-like clauses, returned +/// by an `omp.map.info` operation. +/// \param [out] baseOps - populated with the `var_ptr` values of the +/// corresponding defining operations. +static void +extractMappedBaseValues(llvm::ArrayRef vars, + llvm::SmallVectorImpl &baseOps) { + llvm::transform(vars, std::back_inserter(baseOps), [](mlir::Value map) { + auto mapInfo = map.getDefiningOp(); + assert(mapInfo && "expected all map vars to be defined by omp.map.info"); + + mlir::Value varPtr = mapInfo.getVarPtr(); + if (auto boxAddr = varPtr.getDefiningOp()) + return boxAddr.getVal(); + + return varPtr; + }); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -597,55 +789,41 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter, return storeOp; } -// This helper function implements the functionality of "promoting" -// non-CPTR arguments of use_device_ptr to use_device_addr -// arguments (automagic conversion of use_device_ptr -> -// use_device_addr in these cases). The way we do so currently is -// through the shuffling of operands from the devicePtrOperands to -// deviceAddrOperands where neccesary and re-organizing the types, -// locations and symbols to maintain the correct ordering of ptr/addr -// input -> BlockArg. +// This helper function implements the functionality of "promoting" non-CPTR +// arguments of use_device_ptr to use_device_addr arguments (automagic +// conversion of use_device_ptr -> use_device_addr in these cases). The way we +// do so currently is through the shuffling of operands from the +// devicePtrOperands to deviceAddrOperands, as well as the types, locations and +// symbols. // -// This effectively implements some deprecated OpenMP functionality -// that some legacy applications unfortunately depend on -// (deprecated in specification version 5.2): +// This effectively implements some deprecated OpenMP functionality that some +// legacy applications unfortunately depend on (deprecated in specification +// version 5.2): // -// "If a list item in a use_device_ptr clause is not of type C_PTR, -// the behavior is as if the list item appeared in a use_device_addr -// clause. Support for such list items in a use_device_ptr clause -// is deprecated." +// "If a list item in a use_device_ptr clause is not of type C_PTR, the behavior +// is as if the list item appeared in a use_device_addr clause. Support for +// such list items in a use_device_ptr clause is deprecated." 
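A simplified model of the symbol pre-processing performed by bindMapLike and bindPrivateLike above, with placeholder Sym/BlockArg types: derived-type component symbols are dropped before pairing (they have no independent binding), while common-block symbols are expanded into their member objects; the real code then pairs symbols and arguments with llvm::zip_equal and calls converter.bindSymbol.

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct Sym {
  std::string name;
  bool isDerivedTypeComponent = false;     // models sym->owner().IsDerivedType()
  std::vector<const Sym *> commonMembers;  // models CommonBlockDetails objects
};
struct BlockArg { int index; };

// Map-like clauses: skip structure components.
std::vector<const Sym *> preprocessMapSyms(const std::vector<const Sym *> &syms) {
  std::vector<const Sym *> processed;
  for (const Sym *s : syms)
    if (!s->isDerivedTypeComponent)
      processed.push_back(s);
  return processed;
}

// Private-like clauses: expand common blocks into their members.
std::vector<const Sym *>
preprocessPrivateSyms(const std::vector<const Sym *> &syms) {
  std::vector<const Sym *> processed;
  for (const Sym *s : syms) {
    if (!s->commonMembers.empty())
      processed.insert(processed.end(), s->commonMembers.begin(),
                       s->commonMembers.end());
    else
      processed.push_back(s);
  }
  return processed;
}

// After pre-processing, symbols and entry block arguments pair one-to-one.
void bindAll(const std::vector<const Sym *> &processed,
             const std::vector<BlockArg> &args) {
  assert(processed.size() == args.size());
  for (std::size_t i = 0; i < processed.size(); ++i) {
    // converter.bindSymbol(*processed[i], args[i]) in the real code.
  }
}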
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( llvm::SmallVectorImpl &useDeviceAddrVars, + llvm::SmallVectorImpl &useDeviceAddrSyms, llvm::SmallVectorImpl &useDevicePtrVars, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) { - auto moveElementToBack = [](size_t idx, auto &vector) { - auto *iter = std::next(vector.begin(), idx); - vector.push_back(*iter); - vector.erase(iter); - }; - + llvm::SmallVectorImpl &useDevicePtrSyms) { // Iterate over our use_device_ptr list and shift all non-cptr arguments into // use_device_addr. - for (auto *it = useDevicePtrVars.begin(); it != useDevicePtrVars.end();) { - if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - useDeviceAddrVars.push_back(*it); - // We have to shuffle the symbols around as well, to maintain - // the correct Input -> BlockArg for use_device_ptr/use_device_addr. - // NOTE: However, as map's do not seem to be included currently - // this isn't as pertinent, but we must try to maintain for - // future alterations. I believe the reason they are not currently - // is that the BlockArg assign/lowering needs to be extended - // to a greater set of types. - auto idx = std::distance(useDevicePtrVars.begin(), it); - moveElementToBack(idx, useDeviceTypes); - moveElementToBack(idx, useDeviceLocs); - moveElementToBack(idx, useDeviceSymbols); - it = useDevicePtrVars.erase(it); + auto *varIt = useDevicePtrVars.begin(); + auto *symIt = useDevicePtrSyms.begin(); + while (varIt != useDevicePtrVars.end()) { + if (fir::isa_builtin_cptr_type(fir::unwrapRefType(varIt->getType()))) { + ++varIt; + ++symIt; continue; } - ++it; + + useDeviceAddrVars.push_back(*varIt); + useDeviceAddrSyms.push_back(*symIt); + + varIt = useDevicePtrVars.erase(varIt); + symIt = useDevicePtrSyms.erase(symIt); } } @@ -751,14 +929,14 @@ getDeclareTargetFunctionDevice( /// \param [in] converter - PFT to MLIR conversion interface. /// \param [in] loc - location. /// \param [in] args - symbols of induction variables. -/// \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper +/// \param [in] wrapperArgs - list of parent loop wrappers and their associated /// entry block arguments. -/// \param [in] wrapperArgs - entry block arguments of parent loop wrappers. -static void -genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, llvm::ArrayRef args, - llvm::ArrayRef wrapperSyms = {}, - llvm::ArrayRef wrapperArgs = {}) { +static void genLoopVars( + mlir::Operation *op, lower::AbstractConverter &converter, + mlir::Location &loc, llvm::ArrayRef args, + llvm::ArrayRef< + std::pair> + wrapperArgs = {}) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); auto ®ion = op->getRegion(0); @@ -772,8 +950,8 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, // Bind the entry block arguments of parent wrappers to the corresponding // symbols. - for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs)) - converter.bindSymbol(*arg, prv); + for (auto [argGeneratingOp, args] : wrapperArgs) + bindEntryBlockArgs(converter, argGeneratingOp, args); // The argument is not currently in memory, so make a temporary for the // argument, and store it there, then bind that location to the argument. 
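The rewritten promotion now walks the variable and symbol lists in lock-step and erases as it goes, instead of shuffling elements to the back of several parallel lists. A standalone model of that loop, with a placeholder Var type instead of mlir::Value and plain strings instead of symbols:

#include <cassert>
#include <string>
#include <vector>

struct Var { std::string name; bool isCptr; };
using Syms = std::vector<std::string>;

// Move every non-C_PTR entry (and its symbol) from the use_device_ptr lists
// into the use_device_addr lists, keeping the two lists in sync.
void promoteNonCptr(std::vector<Var> &ptrVars, Syms &ptrSyms,
                    std::vector<Var> &addrVars, Syms &addrSyms) {
  auto varIt = ptrVars.begin();
  auto symIt = ptrSyms.begin();
  while (varIt != ptrVars.end()) {
    if (varIt->isCptr) {
      ++varIt;
      ++symIt;
      continue;
    }
    addrVars.push_back(*varIt);
    addrSyms.push_back(*symIt);
    varIt = ptrVars.erase(varIt);
    symIt = ptrSyms.erase(symIt);
  }
}

int main() {
  std::vector<Var> ptrVars{{"p", true}, {"x", false}};
  Syms ptrSyms{"p", "x"};
  std::vector<Var> addrVars;
  Syms addrSyms;
  promoteNonCptr(ptrVars, ptrSyms, addrVars, addrSyms);
  assert(ptrVars.size() == 1 && addrVars.size() == 1);  // "x" was promoted
  return 0;
}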
@@ -786,21 +964,47 @@ genLoopVars(mlir::Operation *op, lower::AbstractConverter &converter, firOpBuilder.setInsertionPointAfter(storeOp); } -static void -genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef reductionArgs, - llvm::ArrayRef reductionTypes) { +/// Create an entry block for the given region, including the clause-defined +/// arguments specified. +/// +/// \param [in] converter - PFT to MLIR conversion interface. +/// \param [in] args - entry block arguments information for the given +/// operation. +/// \param [in] region - Empty region in which to create the entry block. +static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, + const EntryBlockArgs &args, + mlir::Region ®ion) { + assert(args.isValid() && "invalid args"); + assert(region.empty() && "non-empty region"); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - llvm::SmallVector blockArgLocs(reductionArgs.size(), loc); - - mlir::Block *entryBlock = firOpBuilder.createBlock( - &op->getRegion(0), {}, reductionTypes, blockArgLocs); - // Bind the reduction arguments to their block arguments. - for (auto [arg, prv] : - llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { - converter.bindSymbol(*arg, prv); - } + + llvm::SmallVector types; + llvm::SmallVector locs; + unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() + + args.priv.vars.size() + args.reduction.vars.size() + + args.taskReduction.vars.size() + + args.useDeviceAddr.vars.size(); + types.reserve(numVars); + locs.reserve(numVars); + + auto extractTypeLoc = [&types, &locs](llvm::ArrayRef vals) { + llvm::transform(vals, std::back_inserter(types), + [](mlir::Value v) { return v.getType(); }); + llvm::transform(vals, std::back_inserter(locs), + [](mlir::Value v) { return v.getLoc(); }); + }; + + // Populate block arguments in clause name alphabetical order to match + // expected order by the BlockArgOpenMPOpInterface. + extractTypeLoc(args.inReduction.vars); + extractTypeLoc(args.map.vars); + extractTypeLoc(args.priv.vars); + extractTypeLoc(args.reduction.vars); + extractTypeLoc(args.taskReduction.vars); + extractTypeLoc(args.useDeviceAddr.vars); + extractTypeLoc(args.useDevicePtr.vars); + + return firOpBuilder.createBlock(®ion, {}, types, locs); } static void @@ -828,42 +1032,6 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter, declareTargetOp.setDeclareTarget(deviceType, captureClause); } -/// For an operation that takes `omp.private` values as region args, this util -/// merges the private vars info into the region arguments list. -/// -/// \tparam OMPOP - the OpenMP op that takes `omp.private` inputs. -/// \tparam InfoTy - the type of private info we want to merge; e.g. mlir::Type -/// or mlir::Location fields of the private var list. -/// -/// \param [in] op - the op accepting `omp.private` inputs. -/// \param [in] currentList - the current list of region info that we -/// want to merge private info with. For example this could be the list of types -/// or locations of previous arguments to \op's region. -/// \param [in] infoAccessor - for a private variable, this returns the -/// data we want to merge: type or location. -/// \param [out] allRegionArgsInfo - the merged list of region info. -/// \param [in] addBeforePrivate - `true` if the passed information goes before -/// private information. 
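genEntryBlock, shown in the next hunk, gathers one type and one location per clause-provided value, clause by clause in a fixed (alphabetical) order, before creating a single entry block. A reduced model of that gathering step, where Value is a placeholder for mlir::Value and its getType()/getLoc():

#include <string>
#include <utility>
#include <vector>

struct Value { std::string type; int loc; };

void collect(const std::vector<Value> &vals, std::vector<std::string> &types,
             std::vector<int> &locs) {
  for (const Value &v : vals) {
    types.push_back(v.type);
    locs.push_back(v.loc);
  }
}

std::pair<std::vector<std::string>, std::vector<int>>
gatherEntryBlockSignature(const std::vector<Value> &inReduction,
                          const std::vector<Value> &map,
                          const std::vector<Value> &priv,
                          const std::vector<Value> &reduction) {
  std::vector<std::string> types;
  std::vector<int> locs;
  // Clause-name alphabetical order, matching BlockArgOpenMPOpInterface.
  for (const auto *list : {&inReduction, &map, &priv, &reduction})
    collect(*list, types, locs);
  return {types, locs};
}

Appending per clause in a fixed order is what lets the interface later hand back per-clause slices of the block arguments purely from the operand counts.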
-template -static void -mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef currentList, - llvm::function_ref infoAccessor, - llvm::SmallVectorImpl &allRegionArgsInfo, - bool addBeforePrivate) { - mlir::OperandRange privateVars = op.getPrivateVars(); - - if (addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); - - llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo), - infoAccessor); - - if (!addBeforePrivate) - llvm::transform(currentList, std::back_inserter(allRegionArgsInfo), - [](InfoTy i) { return i; }); -} - //===----------------------------------------------------------------------===// // Op body generation helper structures and functions //===----------------------------------------------------------------------===// @@ -1083,94 +1251,16 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, marker->erase(); } -void mapBodySymbols(lower::AbstractConverter &converter, mlir::Region ®ion, - llvm::ArrayRef mapSyms) { - assert(region.hasOneBlock() && "target must have single region"); - mlir::Block ®ionBlock = region.front(); - // Clones the `bounds` placing them inside the target region and returns them. - auto cloneBound = [&](mlir::Value bound) { - if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { - mlir::Operation *clonedOp = bound.getDefiningOp()->clone(); - regionBlock.push_back(clonedOp); - return clonedOp->getResult(0); - } - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported bound type"); - }; - - auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { - llvm::SmallVector clonedBounds; - for (mlir::Value bound : bounds) - clonedBounds.emplace_back(cloneBound(bound)); - return clonedBounds; - }; - - // Bind the symbols to their corresponding block arguments. - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - const mlir::BlockArgument &arg = region.getArgument(argIndex); - // Avoid capture of a reference to a structured binding. - const semantics::Symbol *sym = argSymbol; - // Structure component symbols don't have bindings. 
- if (sym->owner().IsDerivedType()) - continue; - fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); - auto refType = mlir::dyn_cast(arg.getType()); - if (refType && fir::isa_builtin_cptr_type(refType.getElementType())) { - converter.bindSymbol(*argSymbol, arg); - } else { - extVal.match( - [&](const fir::BoxValue &v) { - converter.bindSymbol(*sym, - fir::BoxValue(arg, cloneBounds(v.getLBounds()), - v.getExplicitParameters(), - v.getExplicitExtents())); - }, - [&](const fir::MutableBoxValue &v) { - converter.bindSymbol( - *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), - v.getMutableProperties())); - }, - [&](const fir::ArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()), - v.getSourceBox())); - }, - [&](const fir::CharArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), - cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()))); - }, - [&](const fir::CharBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharBoxValue(arg, cloneBound(v.getLen()))); - }, - [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); }, - [&](const auto &) { - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported type"); - }); - } - } -} - static void genBodyOfTargetDataOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetDataOp &dataOp, - llvm::ArrayRef useDeviceSymbols, - llvm::ArrayRef useDeviceLocs, - llvm::ArrayRef useDeviceTypes, + mlir::omp::TargetDataOp &dataOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(useDeviceTypes.size() == useDeviceLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = dataOp.getRegion(); - firOpBuilder.createBlock(®ion, {}, useDeviceTypes, useDeviceLocs); - mapBodySymbols(converter, region, useDeviceSymbols); + genEntryBlock(converter, args, dataOp.getRegion()); + bindEntryBlockArgs(converter, dataOp, args); // Insert dummy instruction to remember the insertion position. The // marker will be deleted by clean up passes since there are no uses. @@ -1211,19 +1301,25 @@ static void genBodyOfTargetDataOp( // This is for utilisation with TargetOp. static void genIntermediateCommonBlockAccessors( Fortran::lower::AbstractConverter &converter, - const mlir::Location ¤tLocation, mlir::Region ®ion, + const mlir::Location ¤tLocation, + llvm::ArrayRef mapBlockArgs, llvm::ArrayRef mapSyms) { - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - if (auto *details = - argSymbol->detailsIf()) { - for (auto obj : details->objects()) { - auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( - converter, currentLocation, *obj, region.getArgument(argIndex)); - fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); - fir::ExtendedValue targetCBExv = - getExtendedValue(sexv, targetCBMemberBind); - converter.bindSymbol(*obj, targetCBExv); - } + // Iterate over the symbol list, which will be shorter than the list of + // arguments if new entry block arguments were introduced to implicitly map + // outside values used by the bounds cloned into the target region. In that + // case, the additional block arguments do not need processing here. 
+ for (auto [mapSym, mapArg] : llvm::zip_first(mapSyms, mapBlockArgs)) { + auto *details = mapSym->detailsIf(); + if (!details) + continue; + + for (auto obj : details->objects()) { + auto targetCBMemberBind = Fortran::lower::genCommonBlockMember( + converter, currentLocation, *obj, mapArg); + fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj); + fir::ExtendedValue targetCBExv = + getExtendedValue(sexv, targetCBMemberBind); + converter.bindSymbol(*obj, targetCBExv); } } } @@ -1233,50 +1329,19 @@ static void genIntermediateCommonBlockAccessors( static void genBodyOfTargetOp( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::omp::TargetOp &targetOp, - llvm::ArrayRef mapSyms, - llvm::ArrayRef mapSymLocs, - llvm::ArrayRef mapSymTypes, + mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args, const mlir::Location ¤tLocation, const ConstructQueue &queue, ConstructQueue::const_iterator item, DataSharingProcessor &dsp) { - assert(mapSymTypes.size() == mapSymLocs.size()); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region ®ion = targetOp.getRegion(); - - llvm::SmallVector allRegionArgTypes; - llvm::SmallVector allRegionArgLocs; - mergePrivateVarsInfo(targetOp, mapSymTypes, - llvm::function_ref{ - [](mlir::Value v) { return v.getType(); }}, - allRegionArgTypes, /*addBeforePrivate=*/true); + auto argIface = llvm::cast(*targetOp); - mergePrivateVarsInfo(targetOp, mapSymLocs, - llvm::function_ref{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/true); - - mlir::Block *regionBlock = firOpBuilder.createBlock( - ®ion, {}, allRegionArgTypes, allRegionArgLocs); + mlir::Region ®ion = targetOp.getRegion(); + mlir::Block *entryBlock = genEntryBlock(converter, args, region); if (!enableDelayedPrivatizationStaging) dsp.processStep2(); - mapBodySymbols(converter, region, mapSyms); - - for (auto [argIndex, argSymbol] : - llvm::enumerate(dsp.getDelayedPrivSymbols())) { - argIndex = mapSyms.size() + argIndex; - - const mlir::BlockArgument &arg = region.getArgument(argIndex); - converter.bindSymbol(*argSymbol, - hlfir::translateToExtendedValue( - currentLocation, firOpBuilder, hlfir::Entity{arg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *argSymbol, converter.getFoldingContext())) - .first); - } + bindEntryBlockArgs(converter, targetOp, args); // Check if cloning the bounds introduced any dependency on the outer region. 
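The new comment explains why the symbol list may be shorter than the map block-argument list; iterating with llvm::zip_first means the extra, implicitly-added trailing arguments are simply never visited. A placeholder-typed equivalent of that loop shape:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct Sym { std::string name; };
struct BlockArg { int index; };

// Equivalent of `for (auto [sym, arg] : llvm::zip_first(mapSyms, mapArgs))`:
// the first range bounds the iteration, so block arguments without a
// corresponding symbol are skipped.
void forEachExplicitlyMapped(const std::vector<Sym> &mapSyms,
                             const std::vector<BlockArg> &mapArgs) {
  assert(mapSyms.size() <= mapArgs.size());
  for (std::size_t i = 0; i < mapSyms.size(); ++i) {
    const Sym &sym = mapSyms[i];
    const BlockArg &arg = mapArgs[i];
    (void)sym;
    (void)arg;  // the real code looks for CommonBlockDetails and binds members
  }
}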
// If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1290,12 +1355,12 @@ static void genBodyOfTargetOp( assert(valOp != nullptr); if (mlir::isMemoryEffectFree(valOp)) { mlir::Operation *clonedOp = valOp->clone(); - regionBlock->push_front(clonedOp); + entryBlock->push_front(clonedOp); assert(clonedOp->getNumResults() == 1); - val.replaceUsesWithIf( - clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); + val.replaceUsesWithIf(clonedOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); } else { auto savedIP = firOpBuilder.getInsertionPoint(); firOpBuilder.setInsertionPointAfter(valOp); @@ -1316,18 +1381,23 @@ static void genBodyOfTargetOp( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT), mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType()); + // Get the index of the first non-map argument before modifying mapVars, + // then append an element to mapVars and an associated entry block + // argument at that index. + unsigned insertIndex = + argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs(); targetOp.getMapVarsMutable().append(mapOp); + mlir::Value clonedValArg = region.insertArgument( + insertIndex, copyVal.getType(), copyVal.getLoc()); - mlir::Value clonedValArg = - region.addArgument(copyVal.getType(), copyVal.getLoc()); - firOpBuilder.setInsertionPointToStart(regionBlock); + firOpBuilder.setInsertionPointToStart(entryBlock); auto loadOp = firOpBuilder.create(clonedValArg.getLoc(), clonedValArg); - val.replaceUsesWithIf( - loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); - firOpBuilder.setInsertionPoint(regionBlock, savedIP); + val.replaceUsesWithIf(loadOp->getResult(0), + [entryBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == entryBlock; + }); + firOpBuilder.setInsertionPoint(entryBlock, savedIP); } } valuesDefinedAbove.clear(); @@ -1354,14 +1424,14 @@ static void genBodyOfTargetOp( firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); // If we map a common block using it's symbol e.g. map(tofrom: /common_block/) - // and accessing it's members within the target region, there is a large + // and accessing its members within the target region, there is a large // chance we will end up with uses external to the region accessing the common // resolve these, we do so by generating new common block member accesses // within the region, binding them to the member symbol for the scope of the // region so that subsequent code generation within the region will utilise // our new member accesses we have created. - genIntermediateCommonBlockAccessors(converter, currentLocation, region, - mapSyms); + genIntermediateCommonBlockAccessors( + converter, currentLocation, argIface.getMapBlockArgs(), args.map.syms); if (ConstructQueue::const_iterator next = std::next(item); next != queue.end()) { @@ -1387,7 +1457,7 @@ static OpTy genOpWithBody(const OpWithBodyGenInfo &info, template static OpTy genWrapperOp(lower::AbstractConverter &converter, mlir::Location loc, const ClauseOpsTy &clauseOps, - llvm::ArrayRef blockArgTypes) { + const EntryBlockArgs &args) { static_assert( OpTy::template hasTrait(), "expected a loop wrapper"); @@ -1397,9 +1467,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter, auto op = firOpBuilder.create(loc, clauseOps); // Create entry block with arguments. 
- llvm::SmallVector locs(blockArgTypes.size(), loc); - firOpBuilder.createBlock(&op.getRegion(), /*insertPt=*/{}, blockArgTypes, - locs); + genEntryBlock(converter, args, op.getRegion()); firOpBuilder.setInsertionPoint( lower::genOpenMPTerminator(firOpBuilder, op, loc)); @@ -1481,7 +1549,6 @@ static void genParallelClauses( mlir::Location loc, bool evalOutsideTarget, mlir::omp::ParallelOperands &clauseOps, mlir::omp::NumThreadsClauseOps &numThreadsClauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1498,31 +1565,31 @@ static void genParallelClauses( } cp.processProcBind(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); } static void genSectionsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List &clauses, mlir::Location loc, mlir::omp::SectionsOperands &clauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processNowait(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); } -static void genSimdClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - const List &clauses, mlir::Location loc, - mlir::omp::SimdOperands &clauseOps) { +static void genSimdClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::SimdOperands &clauseOps, + llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAligned(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); cp.processNontemporal(clauseOps); cp.processOrder(clauseOps); - cp.processReduction(loc, clauseOps); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); @@ -1544,24 +1611,16 @@ static void genTargetClauses( lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, bool processHostOnlyClauses, mlir::omp::TargetOperands &clauseOps, - llvm::SmallVectorImpl &mapSyms, - llvm::SmallVectorImpl &mapLocs, - llvm::SmallVectorImpl &mapTypes, - llvm::SmallVectorImpl &deviceAddrSyms, - llvm::SmallVectorImpl &deviceAddrLocs, - llvm::SmallVectorImpl &deviceAddrTypes, - llvm::SmallVectorImpl &devicePtrSyms, - llvm::SmallVectorImpl &devicePtrLocs, - llvm::SmallVectorImpl &devicePtrTypes) { + llvm::SmallVectorImpl &hasDeviceAddrSyms, + llvm::SmallVectorImpl &isDevicePtrSyms, + llvm::SmallVectorImpl &mapSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); - cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, - deviceAddrSyms); + cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, - devicePtrSyms); - cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes); + cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); if (processHostOnlyClauses) cp.processNowait(clauseOps); @@ -1584,32 +1643,26 @@ static void genTargetDataClauses( lower::AbstractConverter &converter, semantics::SemanticsContext 
&semaCtx, lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, mlir::omp::TargetDataOperands &clauseOps, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSyms) { + llvm::SmallVectorImpl &useDeviceAddrSyms, + llvm::SmallVectorImpl &useDevicePtrSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDevice(stmtCtx, clauseOps); cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); cp.processMap(loc, stmtCtx, clauseOps); - cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - cp.processUseDevicePtr(stmtCtx, clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); + cp.processUseDeviceAddr(stmtCtx, clauseOps, useDeviceAddrSyms); + cp.processUseDevicePtr(stmtCtx, clauseOps, useDevicePtrSyms); // This function implements the deprecated functionality of use_device_ptr // that allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy // code may still depend on this functionality, so we should support it // in some manner. We do so currently by simply shifting non-cptr operands - // from the use_device_ptr list into the front of the use_device_addr list - // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and - // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg - // ordering. + // from the use_device_ptr lists into the use_device_addr lists. // TODO: Perhaps create a user provideable compiler option that will // re-introduce a hard-error rather than a warning in these cases. promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - clauseOps.useDeviceAddrVars, clauseOps.useDevicePtrVars, useDeviceTypes, - useDeviceLocs, useDeviceSyms); + clauseOps.useDeviceAddrVars, useDeviceAddrSyms, + clauseOps.useDevicePtrVars, useDevicePtrSyms); } static void genTargetEnterExitUpdateDataClauses( @@ -1674,7 +1727,6 @@ static void genTeamsClauses( mlir::omp::TeamsOperands &clauseOps, mlir::omp::NumTeamsClauseOps &numTeamsClauseOps, mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1693,20 +1745,19 @@ static void genTeamsClauses( cp.processNumTeams(stmtCtx, numTeamsClauseOps); cp.processThreadLimit(stmtCtx, threadLimitClauseOps); } - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); } static void genWsloopClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, mlir::Location loc, mlir::omp::WsloopOperands &clauseOps, - llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processNowait(clauseOps); cp.processOrder(clauseOps); cp.processOrdered(clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); cp.processTODO( @@ -1769,22 +1820,20 @@ genFlushOp(lower::AbstractConverter &converter, lower::SymMap &symTable, converter.getCurrentLocation(), operandRange); } -static mlir::omp::LoopNestOp -genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, 
mlir::Location loc, - const ConstructQueue &queue, ConstructQueue::const_iterator item, - mlir::omp::LoopNestOperands &clauseOps, - llvm::ArrayRef iv, - llvm::ArrayRef wrapperSyms, - llvm::ArrayRef wrapperArgs, - llvm::omp::Directive directive, DataSharingProcessor &dsp) { - assert(wrapperSyms.size() == wrapperArgs.size() && - "Number of symbols and wrapper block arguments must match"); +static mlir::omp::LoopNestOp genLoopNestOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item, mlir::omp::LoopNestOperands &clauseOps, + llvm::ArrayRef iv, + llvm::ArrayRef< + std::pair> + wrapperArgs, + llvm::omp::Directive directive, DataSharingProcessor &dsp) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); auto ivCallback = [&](mlir::Operation *op) { - genLoopVars(op, converter, loc, iv, wrapperSyms, wrapperArgs); + genLoopVars(op, converter, loc, iv, wrapperArgs); return llvm::SmallVector(iv); }; @@ -1817,7 +1866,7 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ClauseProcessor cp(converter, semaCtx, item->clauses); cp.processCollapse(loc, eval, loopRelatedOps, iv); targetOp.getTripCountMutable().assign( - calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps)); + calculateTripCount(firOpBuilder, loc, loopRelatedOps)); } return loopNestOp; } @@ -1878,90 +1927,26 @@ static mlir::omp::ParallelOp genParallelOp( mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item, mlir::omp::ParallelOperands &clauseOps, mlir::omp::NumThreadsClauseOps &numThreadsClauseOps, - llvm::ArrayRef reductionSyms, - llvm::ArrayRef reductionTypes, DataSharingProcessor *dsp, + const EntryBlockArgs &args, DataSharingProcessor *dsp, bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); - return llvm::SmallVector(reductionSyms); + auto genRegionEntryCB = [&](mlir::Operation *op) { + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast(op), args); + return llvm::to_vector(llvm::concat( + args.priv.syms, args.reduction.syms)); }; + assert((!enableDelayedPrivatization || dsp) && + "expected valid DataSharingProcessor"); OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_parallel) .setClauses(&item->clauses) - .setGenRegionEntryCb(reductionCallback) - .setGenSkeletonOnly(isComposite); - - if (!enableDelayedPrivatization) { - auto parallelOp = - genOpWithBody(genInfo, queue, item, clauseOps); - parallelOp.setComposite(isComposite); - if (numThreadsClauseOps.numThreads) { - if (parentTarget) - parentTarget.getNumThreadsMutable().assign( - numThreadsClauseOps.numThreads); - else - parallelOp.getNumThreadsMutable().assign( - numThreadsClauseOps.numThreads); - } - return parallelOp; - } - - assert(dsp && "expected valid DataSharingProcessor"); - auto genRegionEntryCB = [&](mlir::Operation *op) { - auto parallelOp = llvm::cast(op); - - llvm::SmallVector reductionLocs( - clauseOps.reductionVars.size(), loc); - - llvm::SmallVector allRegionArgTypes; - mergePrivateVarsInfo(parallelOp, reductionTypes, - llvm::function_ref{ - [](mlir::Value v) { return v.getType(); }}, - 
allRegionArgTypes, /*addBeforePrivate=*/false); - - llvm::SmallVector allRegionArgLocs; - mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs), - llvm::function_ref{ - [](mlir::Value v) { return v.getLoc(); }}, - allRegionArgLocs, /*addBeforePrivate=*/false); - - mlir::Region ®ion = parallelOp.getRegion(); - firOpBuilder.createBlock(®ion, /*insertPt=*/{}, allRegionArgTypes, - allRegionArgLocs); - - llvm::SmallVector allSymbols( - dsp->getDelayedPrivSymbols()); - allSymbols.append(reductionSyms.begin(), reductionSyms.end()); - - unsigned argIdx = 0; - for (const semantics::Symbol *arg : allSymbols) { - auto bind = [&](const semantics::Symbol *sym) { - mlir::BlockArgument blockArg = region.getArgument(argIdx); - ++argIdx; - converter.bindSymbol(*sym, - hlfir::translateToExtendedValue( - loc, firOpBuilder, hlfir::Entity{blockArg}, - /*contiguousHint=*/ - evaluate::IsSimplyContiguous( - *sym, converter.getFoldingContext())) - .first); - }; + .setGenRegionEntryCb(genRegionEntryCB) + .setGenSkeletonOnly(isComposite) + .setDataSharingProcessor(dsp); - if (const auto *commonDet = - arg->detailsIf()) { - for (const auto &mem : commonDet->objects()) - bind(&*mem); - } else - bind(arg); - } - - return allSymbols; - }; - - genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp); auto parallelOp = genOpWithBody(genInfo, queue, item, clauseOps); parallelOp.setComposite(isComposite); @@ -1984,11 +1969,10 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item, const parser::OmpSectionBlocks §ionBlocks) { - llvm::SmallVector reductionTypes; - llvm::SmallVector reductionSyms; mlir::omp::SectionsOperands clauseOps; + llvm::SmallVector reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, - reductionTypes, reductionSyms); + reductionSyms); auto &builder = converter.getFirOpBuilder(); @@ -2021,15 +2005,20 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // SECTIONS construct. auto sectionsOp = builder.create(loc, clauseOps); - // create entry block with reduction variables as arguments - llvm::SmallVector blockArgLocs(reductionSyms.size(), loc); - builder.createBlock(§ionsOp->getRegion(0), {}, reductionTypes, - blockArgLocs); + // Create entry block with reduction variables as arguments. + EntryBlockArgs args; + // TODO: Add private syms and vars. 
+ args.reduction.syms = reductionSyms; + args.reduction.vars = clauseOps.reductionVars; + + genEntryBlock(converter, args, sectionsOp.getRegion()); mlir::Operation *terminator = lower::genOpenMPTerminator(builder, sectionsOp, loc); auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + genEntryBlock(converter, args, op->getRegion(0)); + bindEntryBlockArgs( + converter, llvm::cast(op), args); return reductionSyms; }; @@ -2123,14 +2112,11 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, .getIsTargetDevice(); mlir::omp::TargetOperands clauseOps; - llvm::SmallVector mapSyms, devicePtrSyms, - deviceAddrSyms; - llvm::SmallVector mapLocs, devicePtrLocs, deviceAddrLocs; - llvm::SmallVector mapTypes, devicePtrTypes, deviceAddrTypes; + llvm::SmallVector mapSyms, isDevicePtrSyms, + hasDeviceAddrSyms; genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, mapSyms, mapLocs, - mapTypes, deviceAddrSyms, deviceAddrLocs, deviceAddrTypes, - devicePtrSyms, devicePtrLocs, devicePtrTypes); + processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, + isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -2239,14 +2225,23 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, clauseOps.mapVars.push_back(mapOp); mapSyms.push_back(&sym); - mapLocs.push_back(baseOp.getLoc()); - mapTypes.push_back(baseOp.getType()); }; lower::pft::visitAllSymbols(eval, captureImplicitMap); auto targetOp = firOpBuilder.create(loc, clauseOps); - genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, mapSyms, - mapLocs, mapTypes, loc, queue, item, dsp); + + llvm::SmallVector mapBaseValues; + extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); + + EntryBlockArgs args; + // TODO: Add in_reduction syms and vars. 
+ args.map.syms = mapSyms; + args.map.vars = mapBaseValues; + args.priv.syms = dsp.getDelayedPrivSymbols(); + args.priv.vars = clauseOps.privateVars; + + genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, + queue, item, dsp); return targetOp; } @@ -2258,18 +2253,28 @@ genTargetDataOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; mlir::omp::TargetDataOperands clauseOps; - llvm::SmallVector useDeviceTypes; - llvm::SmallVector useDeviceLocs; - llvm::SmallVector useDeviceSyms; + llvm::SmallVector useDeviceAddrSyms, + useDevicePtrSyms; genTargetDataClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); + clauseOps, useDeviceAddrSyms, useDevicePtrSyms); auto targetDataOp = converter.getFirOpBuilder().create(loc, clauseOps); - genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, - useDeviceSyms, useDeviceLocs, useDeviceTypes, loc, - queue, item); + + llvm::SmallVector useDeviceAddrBaseValues, + useDevicePtrBaseValues; + extractMappedBaseValues(clauseOps.useDeviceAddrVars, useDeviceAddrBaseValues); + extractMappedBaseValues(clauseOps.useDevicePtrVars, useDevicePtrBaseValues); + + EntryBlockArgs args; + args.useDeviceAddr.syms = useDeviceAddrSyms; + args.useDeviceAddr.vars = useDeviceAddrBaseValues; + args.useDevicePtr.syms = useDevicePtrSyms; + args.useDevicePtr.vars = useDevicePtrBaseValues; + + genBodyOfTargetDataOp(converter, symTable, semaCtx, eval, targetDataOp, args, + loc, queue, item); return targetDataOp; } @@ -2368,21 +2373,28 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::NumTeamsClauseOps numTeamsClauseOps; mlir::omp::ThreadLimitClauseOps threadLimitClauseOps; llvm::SmallVector reductionSyms; - llvm::SmallVector reductionTypes; genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, evalOutsideTarget, clauseOps, numTeamsClauseOps, - threadLimitClauseOps, reductionTypes, reductionSyms); + threadLimitClauseOps, reductionSyms); - auto reductionCallback = [&](mlir::Operation *op) { - genReductionVars(op, converter, loc, reductionSyms, reductionTypes); - return llvm::SmallVector(reductionSyms); + EntryBlockArgs args; + // TODO: Add private syms and vars. 
+  args.reduction.syms = reductionSyms;
+  args.reduction.vars = clauseOps.reductionVars;
+
+  auto genRegionEntryCB = [&](mlir::Operation *op) {
+    genEntryBlock(converter, args, op->getRegion(0));
+    bindEntryBlockArgs(
+        converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
+    return llvm::to_vector(llvm::concat<const semantics::Symbol *const>(
+        args.priv.syms, args.reduction.syms));
   };
 
   auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_teams)
-          .setClauses(&item->clauses)
-          .setGenRegionEntryCb(reductionCallback),
+          .setClauses(&item->clauses)
+          .setGenRegionEntryCb(genRegionEntryCB),
       queue, item, clauseOps);
 
   if (numTeamsClauseOps.numTeamsUpper) {
@@ -2427,22 +2439,20 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter,
       enableDelayedPrivatizationStaging, &symTable);
   dsp.processStep1();
   dsp.processStep2(&distributeClauseOps);
-  llvm::SmallVector<mlir::Type> privateVarTypes{};
-
-  for (mlir::Value privateVar : distributeClauseOps.privateVars)
-    privateVarTypes.push_back(privateVar.getType());
 
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
   genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
                      loopNestClauseOps, iv);
 
+  EntryBlockArgs distributeArgs;
+  distributeArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  distributeArgs.priv.vars = distributeClauseOps.privateVars;
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, privateVarTypes);
+      converter, loc, distributeClauseOps, distributeArgs);
 
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
-                loopNestClauseOps, iv, dsp.getDelayedPrivSymbols(),
-                distributeOp.getRegion().getArguments(),
+                loopNestClauseOps, iv, {{distributeOp, distributeArgs}},
                 llvm::omp::Directive::OMPD_distribute, dsp);
 }
 
@@ -2454,10 +2464,9 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
                             ConstructQueue::const_iterator item) {
   lower::StatementContext stmtCtx;
   mlir::omp::WsloopOperands wsloopClauseOps;
-  llvm::SmallVector<const semantics::Symbol *> reductionSyms;
-  llvm::SmallVector<mlir::Type> reductionTypes;
+  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
   genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                   wsloopClauseOps, reductionTypes, reductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
@@ -2471,13 +2480,15 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
   genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
                      loopNestClauseOps, iv);
 
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, reductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
 
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
-                loopNestClauseOps, iv, reductionSyms,
-                wsloopOp.getRegion().getArguments(),
+                loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}},
                 llvm::omp::Directive::OMPD_do, dsp);
 }
 
@@ -2496,11 +2507,10 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
 
   mlir::omp::ParallelOperands parallelClauseOps;
   mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
-  llvm::SmallVector<const semantics::Symbol *> reductionSyms;
-  llvm::SmallVector<mlir::Type> reductionTypes;
+  llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
                      evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
-                     reductionTypes, reductionSyms);
+                     parallelReductionSyms);
 
   std::optional<DataSharingProcessor> dsp;
   if (enableDelayedPrivatization) {
@@ -2510,9 +2520,15 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
     dsp->processStep1();
     dsp->processStep2(&parallelClauseOps);
   }
+
+  EntryBlockArgs parallelArgs;
+  if (dsp)
+    parallelArgs.priv.syms = dsp->getDelayedPrivSymbols();
+  parallelArgs.priv.vars = parallelClauseOps.privateVars;
+  parallelArgs.reduction.syms = parallelReductionSyms;
+  parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item,
-                parallelClauseOps, numThreadsClauseOps, reductionSyms,
-                reductionTypes,
+                parallelClauseOps, numThreadsClauseOps, parallelArgs,
                 enableDelayedPrivatization ? &dsp.value() : nullptr,
                 /*isComposite=*/false, evalOutsideTarget ? targetOp : nullptr);
 }
@@ -2524,7 +2540,9 @@ static void genStandaloneSimd(lower::AbstractConverter &converter,
                               const ConstructQueue &queue,
                               ConstructQueue::const_iterator item) {
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
@@ -2538,13 +2556,15 @@ static void genStandaloneSimd(lower::AbstractConverter &converter,
   genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
                      loopNestClauseOps, iv);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
 
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
-                loopNestClauseOps, iv,
-                /*wrapperSyms=*/{}, simdOp.getRegion().getArguments(),
+                loopNestClauseOps, iv, {{simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_simd, dsp);
 }
 
@@ -2582,10 +2602,9 @@ static void genCompositeDistributeParallelDo(
   mlir::omp::ParallelOperands parallelClauseOps;
   mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
-  llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
                      evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
-                     parallelReductionTypes, parallelReductionSyms);
+                     parallelReductionSyms);
 
   DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
                            /*shouldCollectPreDeterminedSymbols=*/true,
@@ -2593,10 +2612,14 @@ static void genCompositeDistributeParallelDo(
   dsp.processStep1();
   dsp.processStep2(&parallelClauseOps);
 
+  EntryBlockArgs parallelArgs;
+  parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  parallelArgs.priv.vars = parallelClauseOps.privateVars;
+  parallelArgs.reduction.syms = parallelReductionSyms;
+  parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
-                parallelClauseOps, numThreadsClauseOps, parallelReductionSyms,
-                parallelReductionTypes, &dsp, /*isComposite=*/true,
-                evalOutsideTarget ? targetOp : nullptr);
+                parallelClauseOps, numThreadsClauseOps, parallelArgs, &dsp,
+                /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2605,9 +2628,8 @@ static void genCompositeDistributeParallelDo(
 
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
@@ -2615,27 +2637,23 @@ static void genCompositeDistributeParallelDo(
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  auto &wrapperSyms = wsloopReductionSyms;
-
-  auto wrapperArgs = llvm::to_vector(
-      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
                                        wsloopOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
-                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs}, {wsloopOp, wsloopArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }
 
@@ -2660,10 +2678,9 @@ static void genCompositeDistributeParallelDoSimd(
   mlir::omp::ParallelOperands parallelClauseOps;
   mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
-  llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
                      evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
-                     parallelReductionTypes, parallelReductionSyms);
+                     parallelReductionSyms);
 
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
                            /*shouldCollectPreDeterminedSymbols=*/true,
@@ -2671,10 +2688,14 @@ static void genCompositeDistributeParallelDoSimd(
   dsp.processStep1();
   dsp.processStep2(&parallelClauseOps);
 
+  EntryBlockArgs parallelArgs;
+  parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  parallelArgs.priv.vars = parallelClauseOps.privateVars;
+  parallelArgs.reduction.syms = parallelReductionSyms;
+  parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
-                parallelClauseOps, numThreadsClauseOps, parallelReductionSyms,
-                parallelReductionTypes, &dsp, /*isComposite=*/true,
-                evalOutsideTarget ? targetOp : nullptr);
+                parallelClauseOps, numThreadsClauseOps, parallelArgs, &dsp,
+                /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2683,12 +2704,13 @@ static void genCompositeDistributeParallelDoSimd(
 
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
@@ -2696,32 +2718,33 @@ static void genCompositeDistributeParallelDoSimd(
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  auto &wrapperSyms = wsloopReductionSyms;
-
-  auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
-      distributeOp.getRegion().getArguments(),
-      wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs},
+                 {wsloopOp, wsloopArgs},
+                 {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do_simd, dsp);
 }
 
@@ -2744,7 +2767,9 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
                      loc, distributeClauseOps);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
@@ -2761,26 +2786,23 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Populate entry block arguments with private variables.
+  EntryBlockArgs distributeArgs;
+  // TODO: Add private syms and vars.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
-      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+      converter, loc, distributeClauseOps, distributeArgs);
   distributeOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol order and the block argument order match, so that the
-  // symbol-value bindings created are correct.
-  // TODO: Add omp.distribute private and omp.simd private and reduction args.
-  auto wrapperArgs = llvm::to_vector(
-      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
                                        simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{distributeOp, distributeArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_simd, dsp);
 }
 
@@ -2799,12 +2821,13 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   // Clause processing.
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
-  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
   genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
-                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+                   wsloopClauseOps, wsloopReductionSyms);
 
   mlir::omp::SimdOperands simdClauseOps;
-  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
+  llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
+  genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
+                 simdReductionSyms);
 
   // TODO: Support delayed privatization.
   DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
@@ -2821,25 +2844,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                      loopNestClauseOps, iv);
 
   // Operation creation.
-  // TODO: Add private variables to entry block arguments.
+  EntryBlockArgs wsloopArgs;
+  // TODO: Add private syms and vars.
+  wsloopArgs.reduction.syms = wsloopReductionSyms;
+  wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
-      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+      converter, loc, wsloopClauseOps, wsloopArgs);
   wsloopOp.setComposite(/*val=*/true);
 
-  // TODO: Populate entry block arguments with reduction and private variables.
-  auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
-                                                /*blockArgTypes=*/{});
+  EntryBlockArgs simdArgs;
+  // TODO: Add private syms and vars.
+  simdArgs.reduction.syms = simdReductionSyms;
+  simdArgs.reduction.vars = simdClauseOps.reductionVars;
+  auto simdOp =
+      genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps, simdArgs);
   simdOp.setComposite(/*val=*/true);
 
-  // Construct wrapper entry block list and associated symbols. It is important
-  // that the symbol and block argument order match, so that the symbol-value
-  // bindings created are correct.
-  // TODO: Add omp.wsloop private and omp.simd private and reduction args.
-  auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
-      wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));
-
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
-                loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
+                loopNestClauseOps, iv,
+                {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_do_simd, dsp);
 }
 
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c87182abe3d187..f35e425777141d 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -723,7 +723,7 @@ void ReductionProcessor::addDeclareReduction(
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-    llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
           clause.t)
@@ -754,8 +754,7 @@ void ReductionProcessor::addDeclareReduction(
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   for (const Object &object : objectList) {
     const semantics::Symbol *symbol = object.sym();
-    if (reductionSymbols)
-      reductionSymbols->push_back(symbol);
+    reductionSymbols.push_back(symbol);
     mlir::Value symVal = converter.getSymbolAddress(*symbol);
     mlir::Type eleType;
     auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 580d2cc54da98b..e2a06a257927c5 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -125,8 +125,7 @@ class ReductionProcessor {
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
      llvm::SmallVectorImpl<bool> &reduceVarByRef,
      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
-      llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols =
-          nullptr);
+      llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
 };
 
 template <typename T>
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7291d5b7920790..c26bbe89ee0bb1 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -454,18 +454,15 @@ void insertChildMapInfoIntoParent(
     lower::StatementContext &stmtCtx,
     std::map &parentMemberIndices,
     llvm::SmallVectorImpl<mlir::Value> &mapOperands,
-    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
-    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
-    llvm::SmallVectorImpl<const semantics::Symbol *> *mapSymbols) {
-
+    llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   for (auto indices : parentMemberIndices) {
     bool parentExists = false;
     size_t parentIdx;
 
-    for (parentIdx = 0; parentIdx < mapSymbols->size(); ++parentIdx)
-      if ((*mapSymbols)[parentIdx] == indices.first.sym()) {
+    for (parentIdx = 0; parentIdx < mapSyms.size(); ++parentIdx)
+      if (mapSyms[parentIdx] == indices.first.sym()) {
        parentExists = true;
        break;
      }
@@ -524,12 +521,7 @@ void insertChildMapInfoIntoParent(
       extendBoundsFromMultipleSubscripts(converter, stmtCtx, mapOp,
                                          indices.second.parentObjList);
       mapOperands.push_back(mapOp);
-      if (mapSymTypes)
-        mapSymTypes->push_back(mapOp.getType());
-      if (mapSymLocs)
-        mapSymLocs->push_back(mapOp.getLoc());
-      if (mapSymbols)
-        mapSymbols->push_back(indices.first.sym());
+      mapSyms.push_back(indices.first.sym());
     }
   }
 }
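
Illustrative sketch (not part of the patch): the changes above consistently replace parallel vectors of symbols, locations, and types with a single aggregate that pairs each clause's symbols with its values, and the removed comments stress that symbol order must match block-argument order. The stand-alone C++ fragment below is a minimal, hypothetical approximation of that pairing pattern only; `ArgsSketch`, its members, and the string/int stand-ins are invented for illustration and do not reproduce the real flang `EntryBlockArgs`, `semantics::Symbol`, or `mlir::Value` types.

// Hypothetical stand-ins for the symbol/value pairing pattern used above.
#include <cassert>
#include <string>
#include <vector>

struct ArgsSketch {
  struct Entry {
    std::vector<std::string> syms; // stand-in for the clause's symbols
    std::vector<int> vars;         // stand-in for the clause's SSA values
  };
  Entry priv;
  Entry reduction;

  // Symbols in the same order as the entry block arguments they bind to:
  // first all privatized symbols, then all reduction symbols.
  std::vector<std::string> getSyms() const {
    std::vector<std::string> all(priv.syms);
    all.insert(all.end(), reduction.syms.begin(), reduction.syms.end());
    return all;
  }
};

int main() {
  ArgsSketch args;
  args.priv.syms = {"i"};
  args.priv.vars = {1};
  args.reduction.syms = {"sum"};
  args.reduction.vars = {2};

  // Each symbol list must stay the same length as its value list so the
  // later positional symbol-to-block-argument binding is one-to-one.
  assert(args.priv.syms.size() == args.priv.vars.size());
  assert(args.reduction.syms.size() == args.reduction.vars.size());
  assert(args.getSyms().size() == 2);
  return 0;
}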