Skip to content

Commit

Permalink
[𝘀𝗽𝗿] initial version
Browse files Browse the repository at this point in the history
Created using spr 1.3.4
  • Loading branch information
agozillon committed Oct 4, 2024
2 parents f873fc3 + 7d56f36 commit 1cd1143
Show file tree
Hide file tree
Showing 52 changed files with 3,232 additions and 471 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
llvm::ArrayRef<mlir::Value> lenParams,
bool asTarget = false);

/// Create a two dimensional ArrayAttr containing integer data as
/// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
mlir::ArrayAttr create2DI64ArrayAttr(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);

/// Create a temporary using `fir.alloca`. This function does not hoist.
/// It is the callers responsibility to set the insertion point if
/// hoisting is required.
Expand Down
76 changes: 44 additions & 32 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,16 +889,17 @@ void ClauseProcessor::processMapObjects(
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
std::optional<omp::Object> parentObj;

lower::AddrAndBoundsInfo info =
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
Expand All @@ -907,28 +908,46 @@ void ClauseProcessor::processMapObjects(
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

mlir::Value baseOp = info.rawInput;
if (object.sym()->owner().IsDerivedType()) {
omp::ObjectList objectList = gatherObjects(object, semaCtx);
assert(!objectList.empty() &&
"could not find parent objects of derived type member");
parentObj = objectList[0];
parentMemberIndices.emplace(parentObj.value(),
OmpMapParentAndMemberData{});

if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
llvm::SmallVector<int64_t> indices;
generateMemberPlacementIndices(object, indices, semaCtx);
baseOp = createParentSymAndGenIntermediateMaps(
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
parentMemberIndices[parentObj.value()], asFortran.str(),
mapTypeBits);
}
}

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value baseOp = info.rawInput;
auto location = mlir::NameLoc::get(
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
baseOp.getLoc());
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, location, baseOp,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

if (object.sym()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
if (parentObj.has_value()) {
addChildIndexAndMapToParent(
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
if (mapSyms)
mapSyms->push_back(object.sym());
mapSyms->push_back(object.sym());
if (mapSymTypes)
mapSymTypes->push_back(baseOp.getType());
if (mapSymLocs)
Expand All @@ -949,9 +968,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
Expand Down Expand Up @@ -1003,17 +1020,15 @@ bool ClauseProcessor::processMap(
mapSymLocs, mapSymTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
*ptrMapSyms, mapSymTypes, mapSymLocs);

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
return clauseFound;
}

bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
llvm::SmallVector<const semantics::Symbol *> mapSymbols;

auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
Expand All @@ -1034,9 +1049,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
clauseFound =
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymbols,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
insertChildMapInfoIntoParent(
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
return clauseFound;
}

Expand Down Expand Up @@ -1110,9 +1125,7 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const parser::CharBlock &source) {
Expand All @@ -1125,9 +1138,9 @@ bool ClauseProcessor::processUseDeviceAddr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDeviceAddrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand All @@ -1136,9 +1149,8 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const parser::CharBlock &source) {
Expand All @@ -1151,9 +1163,9 @@ bool ClauseProcessor::processUseDevicePtr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDevicePtrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ class ClauseProcessor {
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
Expand Down
11 changes: 11 additions & 0 deletions flang/lib/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ struct IdTyTemplate {
return designator == other.designator;
}

// Defining an "ordering" which allows types derived from this to be
// utilised in maps and other containers that require comparison
// operators for ordering
bool operator<(const IdTyTemplate &other) const {
return symbol < other.symbol;
}

operator bool() const { return symbol != nullptr; }
};

Expand All @@ -76,6 +83,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
Fortran::semantics::Symbol *sym() const { return identity.symbol; }
const std::optional<ExprTy> &ref() const { return identity.designator; }

bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
return identity < other.identity;
}

IdTy identity;
};
} // namespace tomp::type
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ static void genBodyOfTargetOp(
firOpBuilder, copyVal.getLoc(), copyVal,
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
/*members=*/llvm::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
/*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
Expand Down Expand Up @@ -1792,7 +1792,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{},
name.str(), bounds, /*members=*/{},
/*membersIndex=*/mlir::DenseIntElementsAttr{},
/*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapFlag),
Expand Down
Loading

0 comments on commit 1cd1143

Please sign in to comment.