From 257641aacc6ba24d80d7e55fd85a80bce0708499 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Fri, 12 Jul 2024 18:03:40 +0100 Subject: [PATCH] [OpenMP] Fix use_device_ptr(addr) mappings for Fortran Pointer types Fixes test in https://ontrack-internal.amd.com/browse/SWDEV-471469. --- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 2 +- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 44 +++++++++++++------ 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index e31f323fa57b08..8e96a844631e0e 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6374,7 +6374,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData( // Disable TargetData CodeGen on Device pass. if (Config.IsTargetDevice.value_or(false)) { if (BodyGenCB) - Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv)); + Builder.restoreIP(BodyGenCB(CodeGenIP, BodyGenTy::NoPriv)); return Builder.saveIP(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3383aa5336f569..bfc8770832bd11 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2229,14 +2229,20 @@ void collectMapDataFromMapOperands( } } - auto findMapInfo = [&mapData](llvm::Value *val, unsigned &index) { + auto findMapInfo = [&mapData](llvm::Value *val, unsigned &index, + llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) { + bool found = false; index = 0; for (llvm::Value *basePtr : mapData.OriginalValue) { - if (basePtr == val && !mapData.IsAMember[index]) - return true; + if (basePtr == val) { + found = true; + mapData.Types[index] |= + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mapData.DevicePointers[index] = devInfoTy; + } index++; } - return false; + return found; }; // Process useDevPtr(Addr)Operands @@ -2248,16 +2254,12 @@ void collectMapDataFromMapOperands( mlir::Value offloadPtr = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr(); - // Check if map info is already present for this entry. unsigned infoIndex; + llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr); - if (findMapInfo(moduleTranslation.lookupValue(offloadPtr), infoIndex)) { - mapData.Types[infoIndex] |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; - mapData.DevicePointers[infoIndex] = devInfoTy; - } else { - mapData.OriginalValue.push_back( - moduleTranslation.lookupValue(offloadPtr)); + // Check if map info is already present for this entry. + if (!findMapInfo(origValue, infoIndex, devInfoTy)) { + mapData.OriginalValue.push_back(origValue); mapData.Pointers.push_back(mapData.OriginalValue.back()); mapData.IsDeclareTarget.push_back(false); mapData.BasePointers.push_back(mapData.OriginalValue.back()); @@ -2535,7 +2537,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( - llvm::OpenMPIRBuilder::DeviceInfoTy::None); + mapData.DevicePointers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); @@ -2613,7 +2615,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( - mapData.DevicePointers[memberDataIdx]); + llvm::OpenMPIRBuilder::DeviceInfoTy::None); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); @@ -2975,6 +2977,20 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, if (info.DevicePtrInfoMap.empty()) { builder.restoreIP(codeGenIP); SmallVector phis; + + // For device pass, if use_device_ptr(addr) mappings were present, + // we need to link them here before codegen. + unsigned argIndex = 0; + for (size_t i = 0; i < mapData.BasePointers.size(); ++i) { + if (mapData.DevicePointers[i] == + llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer || + mapData.DevicePointers[i] == + llvm::OpenMPIRBuilder::DeviceInfoTy::Address) { + const auto &arg = region.front().getArgument(argIndex); + moduleTranslation.mapValue(arg, mapData.BasePointers[i]); + argIndex++; + } + } llvm::BasicBlock *continuationBlock = convertOmpOpRegions(region, "omp.data.region", builder, moduleTranslation, bodyGenStatus, &phis);