Skip to content

Commit

Permalink
[Flang][OpenMP] Improve entry block argument creation and binding
Browse files Browse the repository at this point in the history
Commit cherry-picked from PR llvm#110267. Will be removed when rebasing PR stack on
top of a more recent amd-trunk-dev branch.
  • Loading branch information
skatrak committed Oct 10, 2024
1 parent 15949dd commit c3518fb
Show file tree
Hide file tree
Showing 7 changed files with 578 additions and 615 deletions.
4 changes: 1 addition & 3 deletions flang/include/flang/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,7 @@ void insertChildMapInfoIntoParent(
Fortran::lower::StatementContext &stmtCtx,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols);
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSymbols);

mlir::Type getLoopVarType(lower::AbstractConverter &converter,
std::size_t loopVarTypeSize);
Expand Down
78 changes: 22 additions & 56 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter,
static void addUseDeviceClause(
lower::AbstractConverter &converter, const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
for (mlir::Value &operand : operands)
checkMapType(operand.getLoc(), operand.getType());
useDeviceTypes.push_back(operand.getType());
useDeviceLocs.push_back(operand.getLoc());
}

for (const omp::Object &object : objects)
useDeviceSyms.push_back(object.sym());
}
Expand Down Expand Up @@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {

bool ClauseProcessor::processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
isDeviceSyms);
});
}

Expand All @@ -864,14 +858,12 @@ bool ClauseProcessor::processIf(

bool ClauseProcessor::processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
isDeviceSyms);
});
}

Expand All @@ -891,9 +883,7 @@ void ClauseProcessor::processMapObjects(
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
Expand Down Expand Up @@ -948,21 +938,15 @@ void ClauseProcessor::processMapObjects(
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
mapSyms->push_back(object.sym());
if (mapSymTypes)
mapSymTypes->push_back(baseOp.getType());
if (mapSymLocs)
mapSymLocs->push_back(baseOp.getLoc());
mapSyms.push_back(object.sym());
}
}
}

bool ClauseProcessor::processMap(
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
// We always require tracking of symbols, even if the caller does not,
// so we create an optionally used local set of symbols when the mapSyms
// argument is not present.
Expand Down Expand Up @@ -1018,13 +1002,11 @@ bool ClauseProcessor::processMap(

processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
parentMemberIndices, result.mapVars, ptrMapSyms,
mapSymLocs, mapSymTypes);
parentMemberIndices, result.mapVars, *ptrMapSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
result.mapVars, *ptrMapSyms);
return clauseFound;
}

Expand All @@ -1044,7 +1026,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,

processMapObjects(stmtCtx, clauseLocation, std::get<ObjectList>(clause.t),
mapTypeBits, parentMemberIndices, result.mapVars,
&mapSymbols);
mapSymbols);
};

bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
Expand All @@ -1053,7 +1035,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,

insertChildMapInfoIntoParent(
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
mapSymbols);
return clauseFound;
}

Expand All @@ -1071,34 +1053,24 @@ bool ClauseProcessor::processNontemporal(

bool ClauseProcessor::processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionSyms));

if (outReductionTypes) {
outReductionTypes->reserve(outReductionTypes->size() +
reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
[](mlir::Value v) { return v.getType(); });
}

if (outReductionSyms)
llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
});
}

Expand All @@ -1124,8 +1096,6 @@ bool ClauseProcessor::processEnter(

bool ClauseProcessor::processUseDeviceAddr(
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
Expand All @@ -1137,19 +1107,16 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDeviceAddrVars,
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
useDeviceSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
result.useDeviceAddrVars, useDeviceSyms);
return clauseFound;
}

bool ClauseProcessor::processUseDevicePtr(
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

Expand All @@ -1162,12 +1129,11 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.useDevicePtrVars,
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
useDeviceSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
result.useDevicePtrVars, useDeviceSyms);
return clauseFound;
}

Expand Down
38 changes: 12 additions & 26 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ class ClauseProcessor {
mlir::omp::FinalClauseOps &result) const;
bool processHasDeviceAddr(
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
Expand Down Expand Up @@ -104,43 +102,33 @@ class ClauseProcessor {
mlir::omp::IfClauseOps &result) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

// This method is used to process a map clause.
// The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
bool processMap(
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
// The optional parameter mapSyms is used to store the original Fortran symbol
// for the map operands. It may be used later on to create the block_arguments
// for some of the directives that require it.
bool processMap(mlir::Location currentLocation,
lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
nullptr) const;
bool processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result);
bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
nullptr) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
bool processUseDevicePtr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDevicePtrClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;

// Call this method for these clauses that should be supported but are not
Expand Down Expand Up @@ -180,9 +168,7 @@ class ClauseProcessor {
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;

lower::AbstractConverter &converter;
semantics::SemanticsContext &semaCtx;
Expand Down
Loading

0 comments on commit c3518fb

Please sign in to comment.