Skip to content

Commit 784c7fb

Browse files
committed
[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers
This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive. Since both MLIR and Clang now support custom mappers, I've made the relative params required instead of optional as well. Depends on #121005
1 parent 55e38d7 commit 784c7fb

File tree

7 files changed

+437
-99
lines changed

7 files changed

+437
-99
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8888,8 +8888,8 @@ static void emitOffloadingArraysAndArgs(
88888888
return MFunc;
88898889
};
88908890
OMPBuilder.emitOffloadingArraysAndArgs(
8891-
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
8892-
ForEndCall, DeviceAddrCB, CustomMapperCB);
8891+
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8892+
IsNonContiguous, ForEndCall, DeviceAddrCB);
88938893
}
88948894

88958895
/// Check for inner distribute directive.
@@ -9098,9 +9098,10 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
90989098
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
90999099
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91009100

9101-
auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
9102-
ElemTy, Name, CustomMapperCB);
9103-
UDMMap.try_emplace(D, NewFn);
9101+
llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
9102+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
9103+
assert(NewFn && "Unexpected error in emitUserDefinedMapper");
9104+
UDMMap.try_emplace(D, *NewFn);
91049105
if (CGF)
91059106
FunctionUDMMap[CGF->CurFn].push_back(D);
91069107
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,7 @@ class OpenMPIRBuilder {
23992399
CurInfo.NonContigInfo.Strides.end());
24002400
}
24012401
};
2402+
using MapInfosOrErrorTy = Expected<MapInfosTy &>;
24022403

24032404
/// Callback function type for functions emitting the host fallback code that
24042405
/// is executed when the kernel launch fails. It takes an insertion point as
@@ -2475,9 +2476,9 @@ class OpenMPIRBuilder {
24752476
/// including base pointers, pointers, sizes, map types, user-defined mappers.
24762477
void emitOffloadingArrays(
24772478
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2478-
TargetDataInfo &Info, bool IsNonContiguous = false,
2479-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2480-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2479+
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
2480+
bool IsNonContiguous = false,
2481+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24812482

24822483
/// Allocates memory for and populates the arrays required for offloading
24832484
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
@@ -2488,9 +2489,9 @@ class OpenMPIRBuilder {
24882489
void emitOffloadingArraysAndArgs(
24892490
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
24902491
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2492+
function_ref<Value *(unsigned int)> CustomMapperCB,
24912493
bool IsNonContiguous = false, bool ForEndCall = false,
2492-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2493-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2494+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24942495

24952496
/// Creates offloading entry for the provided entry ID \a ID, address \a
24962497
/// Addr, size \a Size, and flags \a Flags.
@@ -2950,12 +2951,12 @@ class OpenMPIRBuilder {
29502951
/// \param FuncName Optional param to specify mapper function name.
29512952
/// \param CustomMapperCB Optional callback to generate code related to
29522953
/// custom mappers.
2953-
Function *emitUserDefinedMapper(
2954-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
2955-
llvm::Value *BeginArg)>
2954+
Expected<Function *> emitUserDefinedMapper(
2955+
function_ref<MapInfosOrErrorTy(
2956+
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29562957
PrivAndGenMapInfoCB,
29572958
llvm::Type *ElemTy, StringRef FuncName,
2958-
function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
2959+
function_ref<bool(unsigned int, Function **)> CustomMapperCB);
29592960

29602961
/// Generator for '#omp target data'
29612962
///
@@ -2969,21 +2970,21 @@ class OpenMPIRBuilder {
29692970
/// \param IfCond Value which corresponds to the if clause condition.
29702971
/// \param Info Stores all information realted to the Target Data directive.
29712972
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
2973+
/// \param CustomMapperCB Callback to generate code related to
2974+
/// custom mappers.
29722975
/// \param BodyGenCB Optional Callback to generate the region code.
29732976
/// \param DeviceAddrCB Optional callback to generate code related to
29742977
/// use_device_ptr and use_device_addr.
2975-
/// \param CustomMapperCB Optional callback to generate code related to
2976-
/// custom mappers.
29772978
InsertPointOrErrorTy createTargetData(
29782979
const LocationDescription &Loc, InsertPointTy AllocaIP,
29792980
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
29802981
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
2982+
function_ref<Value *(unsigned int)> CustomMapperCB,
29812983
omp::RuntimeFunction *MapperFunc = nullptr,
29822984
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
29832985
BodyGenTy BodyGenType)>
29842986
BodyGenCB = nullptr,
29852987
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2986-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
29872988
Value *SrcLocInfo = nullptr);
29882989

29892990
using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -2999,6 +3000,7 @@ class OpenMPIRBuilder {
29993000
/// \param IsOffloadEntry whether it is an offload entry.
30003001
/// \param CodeGenIP The insertion point where the call to the outlined
30013002
/// function should be emitted.
3003+
/// \param Info Stores all information realted to the Target directive.
30023004
/// \param EntryInfo The entry information about the function.
30033005
/// \param DefaultAttrs Structure containing the default attributes, including
30043006
/// numbers of threads and teams to launch the kernel with.
@@ -3010,20 +3012,23 @@ class OpenMPIRBuilder {
30103012
/// \param BodyGenCB Callback that will generate the region code.
30113013
/// \param ArgAccessorFuncCB Callback that will generate accessors
30123014
/// instructions for passed in target arguments where neccessary
3015+
/// \param CustomMapperCB Callback to generate code related to
3016+
/// custom mappers.
30133017
/// \param Dependencies A vector of DependData objects that carry
30143018
/// dependency information as passed in the depend clause
30153019
/// \param HasNowait Whether the target construct has a `nowait` clause or
30163020
/// not.
30173021
InsertPointOrErrorTy createTarget(
30183022
const LocationDescription &Loc, bool IsOffloadEntry,
30193023
OpenMPIRBuilder::InsertPointTy AllocaIP,
3020-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
3024+
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
30213025
TargetRegionEntryInfo &EntryInfo,
30223026
const TargetKernelDefaultAttrs &DefaultAttrs,
30233027
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
30243028
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30253029
TargetBodyGenCallbackTy BodyGenCB,
30263030
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3031+
function_ref<Value *(unsigned int)> CustomMapperCB,
30273032
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30283033

30293034
/// Returns __kmpc_for_static_init_* runtime function for the specified

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6555,12 +6555,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65556555
const LocationDescription &Loc, InsertPointTy AllocaIP,
65566556
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
65576557
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6558+
function_ref<Value *(unsigned int)> CustomMapperCB,
65586559
omp::RuntimeFunction *MapperFunc,
65596560
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
65606561
BodyGenTy BodyGenType)>
65616562
BodyGenCB,
6562-
function_ref<void(unsigned int, Value *)> DeviceAddrCB,
6563-
function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
6563+
function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
65646564
if (!updateToLocation(Loc))
65656565
return InsertPointTy();
65666566

@@ -6586,8 +6586,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65866586
InsertPointTy CodeGenIP) -> Error {
65876587
MapInfo = &GenMapInfoCB(Builder.saveIP());
65886588
emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6589-
/*IsNonContiguous=*/true, DeviceAddrCB,
6590-
CustomMapperCB);
6589+
CustomMapperCB,
6590+
/*IsNonContiguous=*/true, DeviceAddrCB);
65916591

65926592
TargetDataRTArgs RTArgs;
65936593
emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7488,24 +7488,26 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
74887488

74897489
void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
74907490
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
7491-
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
7492-
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB,
7493-
function_ref<Value *(unsigned int)> CustomMapperCB) {
7494-
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous,
7495-
DeviceAddrCB, CustomMapperCB);
7491+
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7492+
function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
7493+
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7494+
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
7495+
IsNonContiguous, DeviceAddrCB);
74967496
emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
74977497
}
74987498

74997499
static void
75007500
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
75017501
OpenMPIRBuilder::InsertPointTy AllocaIP,
7502+
OpenMPIRBuilder::TargetDataInfo &Info,
75027503
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
75037504
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
75047505
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
75057506
SmallVectorImpl<Value *> &Args,
75067507
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7507-
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7508-
bool HasNoWait = false) {
7508+
function_ref<Value *(unsigned int)> CustomMapperCB,
7509+
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
7510+
bool HasNoWait) {
75097511
// Generate a function call to the host fallback implementation of the target
75107512
// region. This is called by the host when no offload entry was generated for
75117513
// the target region and when the offloading call fails at runtime.
@@ -7583,7 +7585,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
75837585
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
75847586
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
75857587
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7586-
RTArgs, MapInfo,
7588+
RTArgs, MapInfo, CustomMapperCB,
75877589
/*IsNonContiguous=*/true,
75887590
/*ForEndCall=*/false);
75897591

@@ -7687,12 +7689,14 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
76877689

76887690
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
76897691
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7690-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7692+
InsertPointTy CodeGenIP, TargetDataInfo &Info,
7693+
TargetRegionEntryInfo &EntryInfo,
76917694
const TargetKernelDefaultAttrs &DefaultAttrs,
76927695
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7693-
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7696+
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
76947697
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
76957698
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7699+
function_ref<Value *(unsigned int)> CustomMapperCB,
76967700
SmallVector<DependData> Dependencies, bool HasNowait) {
76977701

76987702
if (!updateToLocation(Loc))
@@ -7707,16 +7711,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
77077711
// and ArgAccessorFuncCB
77087712
if (Error Err = emitTargetOutlinedFunction(
77097713
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7710-
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7714+
OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
77117715
return Err;
77127716

77137717
// If we are not on the target device, then we need to generate code
77147718
// to make a remote call (offload) to the previously outlined function
77157719
// that represents the target region. Do that now.
77167720
if (!Config.isTargetDevice())
7717-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
7718-
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7719-
HasNowait);
7721+
emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
7722+
IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
7723+
CustomMapperCB, Dependencies, HasNowait);
77207724
return Builder.saveIP();
77217725
}
77227726

@@ -8041,9 +8045,9 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
80418045
OffloadingArgs);
80428046
}
80438047

8044-
Function *OpenMPIRBuilder::emitUserDefinedMapper(
8045-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
8046-
llvm::Value *BeginArg)>
8048+
Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
8049+
function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
8050+
llvm::Value *BeginArg)>
80478051
GenMapInfoCB,
80488052
Type *ElemTy, StringRef FuncName,
80498053
function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
@@ -8117,7 +8121,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
81178121
PtrPHI->addIncoming(PtrBegin, HeadBB);
81188122

81198123
// Get map clause information. Fill up the arrays with all mapped variables.
8120-
MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8124+
MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8125+
if (!Info)
8126+
return Info.takeError();
81218127

81228128
// Call the runtime API __tgt_mapper_num_components to get the number of
81238129
// pre-existing components.
@@ -8129,20 +8135,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
81298135
Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
81308136

81318137
// Fill up the runtime mapper handle for all components.
8132-
for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
8138+
for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
81338139
Value *CurBaseArg =
8134-
Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy());
8140+
Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy());
81358141
Value *CurBeginArg =
8136-
Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy());
8137-
Value *CurSizeArg = Info.Sizes[I];
8138-
Value *CurNameArg = Info.Names.size()
8139-
? Info.Names[I]
8142+
Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy());
8143+
Value *CurSizeArg = Info->Sizes[I];
8144+
Value *CurNameArg = Info->Names.size()
8145+
? Info->Names[I]
81408146
: Constant::getNullValue(Builder.getPtrTy());
81418147

81428148
// Extract the MEMBER_OF field from the map type.
81438149
Value *OriMapType = Builder.getInt64(
81448150
static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8145-
Info.Types[I]));
8151+
Info->Types[I]));
81468152
Value *MemberMapType =
81478153
Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
81488154

@@ -8263,9 +8269,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
82638269

82648270
void OpenMPIRBuilder::emitOffloadingArrays(
82658271
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8266-
TargetDataInfo &Info, bool IsNonContiguous,
8267-
function_ref<void(unsigned int, Value *)> DeviceAddrCB,
8268-
function_ref<Value *(unsigned int)> CustomMapperCB) {
8272+
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
8273+
bool IsNonContiguous,
8274+
function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
82698275

82708276
// Reset the array information.
82718277
Info.clearArrayInfo();

0 commit comments

Comments
 (0)