Skip to content

Merge Declare Mapper patches. #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8921,17 +8921,17 @@ static void emitOffloadingArraysAndArgs(
};

auto CustomMapperCB = [&](unsigned int I) {
llvm::Value *MFunc = nullptr;
llvm::Function *MFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
Info.HasMapper = true;
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
}
return MFunc;
};
OMPBuilder.emitOffloadingArraysAndArgs(
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
ForEndCall, DeviceAddrCB, CustomMapperCB);
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
IsNonContiguous, ForEndCall, DeviceAddrCB));
}

/// Check for inner distribute directive.
Expand Down Expand Up @@ -9124,24 +9124,24 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
return CombinedInfo;
};

auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
auto CustomMapperCB = [&](unsigned I) {
llvm::Function *MapperFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
// Call the corresponding mapper function.
*MapperFunc = getOrCreateUserDefinedMapperFunc(
MapperFunc = getOrCreateUserDefinedMapperFunc(
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
assert(*MapperFunc && "Expect a valid mapper function is available.");
return true;
assert(MapperFunc && "Expect a valid mapper function is available.");
}
return false;
return MapperFunc;
};

SmallString<64> TyStr;
llvm::raw_svector_ostream Out(TyStr);
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
std::string Name = getName({"omp_mapper", TyStr, D->getName()});

auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
ElemTy, Name, CustomMapperCB);
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
UDMMap.try_emplace(D, NewFn);
if (CGF)
FunctionUDMMap[CGF->CurFn].push_back(D);
Expand Down Expand Up @@ -10493,7 +10493,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
};

auto CustomMapperCB = [&](unsigned int I) {
llvm::Value *MFunc = nullptr;
llvm::Function *MFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
Info.HasMapper = true;
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
Expand All @@ -10513,7 +10513,8 @@ void CGOpenMPRuntime::emitTargetDataCalls(
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
cantFail(OMPBuilder.createTargetData(
OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc));
CustomMapperCB,
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc));
CGF.Builder.restoreIP(AfterIP);
}

Expand Down
3 changes: 2 additions & 1 deletion flang/include/flang/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::ArrayRef<mlir::Value> members,
mlir::ArrayAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap = false);
bool partialMap = false,
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());

void insertChildMapInfoIntoParent(
Fortran::lower::AbstractConverter &converter,
Expand Down
33 changes: 28 additions & 5 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,8 +969,11 @@ void ClauseProcessor::processMapObjects(
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::StringRef mapperIdNameRef) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr mapperId;
std::string mapperIdName = mapperIdNameRef.str();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
Expand Down Expand Up @@ -1003,6 +1006,20 @@ void ClauseProcessor::processMapObjects(
}
}

if (!mapperIdName.empty()) {
if (mapperIdName == "default") {
auto &typeSpec = object.sym()->owner().IsDerivedType()
? *object.sym()->owner().derivedTypeSpec()
: object.sym()->GetType()->derivedTypeSpec();
mapperIdName = typeSpec.name().ToString() + ".default";
mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope());
}
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
"mapper not found");
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);
mapperIdName.clear();
}
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
Expand All @@ -1016,7 +1033,8 @@ void ClauseProcessor::processMapObjects(
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(),
/*partialMap=*/false, mapperId);

if (parentObj.has_value()) {
parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
Expand Down Expand Up @@ -1047,6 +1065,7 @@ bool ClauseProcessor::processMap(
const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t;
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
std::string mapperIdName;
// If the map type is specified, then process it else Tofrom is the
// default.
Map::MapType type = mapType.value_or(Map::MapType::Tofrom);
Expand Down Expand Up @@ -1090,13 +1109,17 @@ bool ClauseProcessor::processMap(
"Support for iterator modifiers is not implemented yet");
}
if (mappers) {
TODO(currentLocation,
"Support for mapper modifiers is not implemented yet");
assert(mappers->size() == 1 && "more than one mapper");
mapperIdName = mappers->front().v.id().symbol->name().ToString();
if (mapperIdName != "default")
mapperIdName = converter.mangleName(
mapperIdName, mappers->front().v.id().symbol->owner());
}

processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
parentMemberIndices, result.mapVars, *ptrMapSyms);
parentMemberIndices, result.mapVars, *ptrMapSyms,
mapperIdName);
};

bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ class ClauseProcessor {
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::StringRef mapperIdNameRef = "") const;

lower::AbstractConverter &converter;
semantics::SemanticsContext &semaCtx;
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::ArrayRef<mlir::Value> members,
mlir::ArrayAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap) {
bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
Expand All @@ -149,6 +149,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
mapperId,
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
return op;
Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ mlir::omp::MapInfoOp createMapInfoOp(
mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy, bool partialMap = false) {
mlir::Type retTy, bool partialMap = false,
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) {
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
Expand All @@ -70,6 +71,7 @@ mlir::omp::MapInfoOp createMapInfoOp(
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
mapperId,
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name), builder.getBoolAttr(partialMap));

Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class MapInfoFinalizationPass
/*members=*/mlir::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
/*mapperId*/ mlir::FlatSymbolRefAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*name=*/builder.getStringAttr(""),
Expand Down Expand Up @@ -331,7 +332,8 @@ class MapInfoFinalizationPass
builder.getIntegerAttr(
builder.getIntegerType(64, false),
getDescriptorMapType(op.getMapType().value_or(0), target)),
op.getMapCaptureTypeAttr(), op.getNameAttr(),
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
op.getNameAttr(),
/*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newDescParentMapOp.getResult());
op->erase();
Expand Down Expand Up @@ -629,6 +631,7 @@ class MapInfoFinalizationPass
// /*members=*/mlir::ValueRange{},
// /*members_index=*/mlir::ArrayAttr{},
// /*bounds=*/bounds, op.getMapTypeAttr(),
// /*mapperId*/ mlir::FlatSymbolRefAttr(),
// builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
// mlir::omp::VariableCaptureKind::ByRef),
// builder.getStringAttr(op.getNameAttr().strref() + "." +
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass
/*bounds=*/ValueRange{},
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
mapTypeTo),
/*mapperId*/ mlir::FlatSymbolRefAttr(),
builder.getAttr<omp::VariableCaptureKindAttr>(
omp::VariableCaptureKind::ByRef),
StringAttr(), builder.getBoolAttr(false));
Expand Down
27 changes: 23 additions & 4 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -936,9 +936,9 @@ func.func @omp_map_info_descriptor_type_conversion(%arg0 : !fir.ref<!fir.box<!fi
%1 = omp.map.info var_ptr(%0 : !fir.llvm_ptr<!fir.ref<i32>>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
// CHECK: %[[DESC_MAP:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, delete) capture(ByRef) members(%[[MEMBER_MAP]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""}
%2 = omp.map.info var_ptr(%arg0 : !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.box<!fir.heap<i32>>) map_clauses(always, delete) capture(ByRef) members(%1 : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.heap<i32>>> {name = ""}
// CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr)
// CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr)
omp.target_exit_data map_entries(%2 : !fir.ref<!fir.box<!fir.heap<i32>>>)
return
return
}

// -----
Expand All @@ -956,8 +956,8 @@ func.func @omp_map_info_derived_type_explicit_member_conversion(%arg0 : !fir.ref
%3 = fir.field_index real, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>
%4 = fir.coordinate_of %arg0, %3 : (!fir.ref<!fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.field) -> !fir.ref<f32>
// CHECK: %[[MAP_MEMBER_2:.*]] = omp.map.info var_ptr(%[[GEP_2]] : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtype%real"}
%5 = omp.map.info var_ptr(%4 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "dtype%real"}
// CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true}
%5 = omp.map.info var_ptr(%4 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "dtype%real"}
// CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true}
%6 = omp.map.info var_ptr(%arg0 : !fir.ref<!fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) members(%2, %5 : [2], [0] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "dtype", partial_map = true}
// CHECK: omp.target map_entries(%[[MAP_MEMBER_1]] -> %[[ARG_1:.*]], %[[MAP_MEMBER_2]] -> %[[ARG_2:.*]], %[[MAP_PARENT]] -> %[[ARG_3:.*]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
omp.target map_entries(%2 -> %arg1, %5 -> %arg2, %6 -> %arg3 : !fir.ref<i32>, !fir.ref<f32>, !fir.ref<!fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>>) {
Expand Down Expand Up @@ -1279,3 +1279,22 @@ func.func @map_nested_dtype_alloca_mem2(%arg0 : !fir.ref<!fir.type<_QFRecTy{i:f3
}
return
}

// -----

// CHECK-LABEL: omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> {
omp.declare_mapper @my_mapper : !fir.type<_QFdeclare_mapperTmy_type{data:i32}> {
// CHECK: ^bb0(%[[VAL_0:.*]]: !llvm.ptr):
^bb0(%0: !fir.ref<!fir.type<_QFdeclare_mapperTmy_type{data:i32}>>):
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
%1 = fir.field_index data, !fir.type<_QFdeclare_mapperTmy_type{data:i32}>
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>
%2 = fir.coordinate_of %0, %1 : (!fir.ref<!fir.type<_QFdeclare_mapperTmy_type{data:i32}>>, !fir.field) -> !fir.ref<i32>
// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%[[VAL_4:.*]]"}
%3 = omp.map.info var_ptr(%2 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var%data"}
// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]] : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true}
%4 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFdeclare_mapperTmy_type{data:i32}>>, !fir.type<_QFdeclare_mapperTmy_type{data:i32}>) map_clauses(tofrom) capture(ByRef) members(%3 : [0] : !fir.ref<i32>) -> !fir.ref<!fir.type<_QFdeclare_mapperTmy_type{data:i32}>> {name = "var", partial_map = true}
// CHECK: omp.declare_mapper.info map_entries(%[[VAL_5]], %[[VAL_3]] : !llvm.ptr, !llvm.ptr)
omp.declare_mapper.info map_entries(%4, %3 : !fir.ref<!fir.type<_QFdeclare_mapperTmy_type{data:i32}>>, !fir.ref<i32>)
// CHECK: }
}
Loading
Loading