Skip to content

Commit

Permalink
[OpenMP] Fix use_device_ptr(addr) mappings for Fortran Pointer types
Browse files Browse the repository at this point in the history
  • Loading branch information
TIFitis committed Jul 12, 2024
1 parent 6c280be commit 257641a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
2 changes: 1 addition & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6374,7 +6374,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
// Disable TargetData CodeGen on Device pass.
if (Config.IsTargetDevice.value_or(false)) {
if (BodyGenCB)
Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv));
return Builder.saveIP();
}

Expand Down
44 changes: 30 additions & 14 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2229,14 +2229,20 @@ void collectMapDataFromMapOperands(
}
}

auto findMapInfo = [&mapData](llvm::Value *val, unsigned &index) {
auto findMapInfo = [&mapData](llvm::Value *val, unsigned &index,
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
bool found = false;
index = 0;
for (llvm::Value *basePtr : mapData.OriginalValue) {
if (basePtr == val && !mapData.IsAMember[index])
return true;
if (basePtr == val) {
found = true;
mapData.Types[index] |=
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
mapData.DevicePointers[index] = devInfoTy;
}
index++;
}
return false;
return found;
};

// Process useDevPtr(Addr)Operands
Expand All @@ -2248,16 +2254,12 @@ void collectMapDataFromMapOperands(
mlir::Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();

// Check if map info is already present for this entry.
unsigned infoIndex;
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);

if (findMapInfo(moduleTranslation.lookupValue(offloadPtr), infoIndex)) {
mapData.Types[infoIndex] |=
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
mapData.DevicePointers[infoIndex] = devInfoTy;
} else {
mapData.OriginalValue.push_back(
moduleTranslation.lookupValue(offloadPtr));
// Check if map info is already present for this entry.
if (!findMapInfo(origValue, infoIndex, devInfoTy)) {
mapData.OriginalValue.push_back(origValue);
mapData.Pointers.push_back(mapData.OriginalValue.back());
mapData.IsDeclareTarget.push_back(false);
mapData.BasePointers.push_back(mapData.OriginalValue.back());
Expand Down Expand Up @@ -2535,7 +2537,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
mapData.DevicePointers[mapDataIndex]);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
Expand Down Expand Up @@ -2613,7 +2615,7 @@ static void processMapMembersWithParent(

combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
mapData.DevicePointers[memberDataIdx]);
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
combinedInfo.Names.emplace_back(
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));

Expand Down Expand Up @@ -2975,6 +2977,20 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
SmallVector<llvm::PHINode *> phis;

// For device pass, if use_device_ptr(addr) mappings were present,
// we need to link them here before codegen.
unsigned argIndex = 0;
for (size_t i = 0; i < mapData.BasePointers.size(); ++i) {
if (mapData.DevicePointers[i] ==
llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
mapData.DevicePointers[i] ==
llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
const auto &arg = region.front().getArgument(argIndex);
moduleTranslation.mapValue(arg, mapData.BasePointers[i]);
argIndex++;
}
}
llvm::BasicBlock *continuationBlock =
convertOmpOpRegions(region, "omp.data.region", builder,
moduleTranslation, bodyGenStatus, &phis);
Expand Down

0 comments on commit 257641a

Please sign in to comment.