Skip to content

Commit 295474b

Browse files
committed
[Flang][OpenMP] Remove use of non-reference values from MapInfoOp
This patch removes the `val` field from the MapInfoOp. Previously, when lowering TargetOp, the bounds information for the BoxValues was also being mapped. Instead, the ops that define these bounds are now duplicated inside the target region to prevent mapping of non-reference-typed values.
1 parent ea47887 commit 295474b

File tree

11 files changed

+221
-228
lines changed

11 files changed

+221
-228
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 89 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "flang/Semantics/tools.h"
2929
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3030
#include "mlir/Dialect/SCF/IR/SCF.h"
31+
#include "mlir/Transforms/RegionUtils.h"
3132
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3233
#include "llvm/Support/CommandLine.h"
3334

@@ -1709,26 +1710,22 @@ static mlir::omp::MapInfoOp
17091710
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
17101711
mlir::Value baseAddr, std::stringstream &name,
17111712
mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
1712-
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
1713-
bool isVal = false) {
1714-
mlir::Value val, varPtr, varPtrPtr;
1713+
mlir::omp::VariableCaptureKind mapCaptureType,
1714+
mlir::Type retTy) {
1715+
mlir::Value varPtr, varPtrPtr;
17151716
mlir::TypeAttr varType;
17161717

17171718
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
17181719
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
17191720
retTy = baseAddr.getType();
17201721
}
17211722

1722-
if (isVal)
1723-
val = baseAddr;
1724-
else
1725-
varPtr = baseAddr;
1726-
1727-
if (auto ptrType = llvm::dyn_cast<mlir::omp::PointerLikeType>(retTy))
1728-
varType = mlir::TypeAttr::get(ptrType.getElementType());
1723+
varPtr = baseAddr;
1724+
varType = mlir::TypeAttr::get(
1725+
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
17291726

17301727
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
1731-
loc, retTy, val, varPtr, varType, varPtrPtr, bounds,
1728+
loc, retTy, varPtr, varType, varPtrPtr, bounds,
17321729
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
17331730
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
17341731
builder.getStringAttr(name.str()));
@@ -2489,21 +2486,27 @@ static void genBodyOfTargetOp(
24892486
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
24902487
mlir::Region &region = targetOp.getRegion();
24912488

2492-
firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
2489+
auto *regionBlock =
2490+
firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
24932491

24942492
unsigned argIndex = 0;
2495-
unsigned blockArgsIndex = mapSymbols.size();
2496-
2497-
// The block arguments contain the map_operands followed by the bounds in
2498-
// order. This returns a vector containing the next 'n' block arguments for
2499-
// the bounds.
2500-
auto extractBoundArgs = [&](auto n) {
2501-
llvm::SmallVector<mlir::Value> argExtents;
2502-
while (n--) {
2503-
argExtents.push_back(fir::getBase(region.getArgument(blockArgsIndex)));
2504-
blockArgsIndex++;
2493+
2494+
// Clones the operation defining `bound` into the target region and returns the clone's first result.
2495+
auto cloneBound = [&](mlir::Value bound) {
2496+
if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
2497+
mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
2498+
regionBlock->push_back(clonedOp);
2499+
return clonedOp->getResult(0);
25052500
}
2506-
return argExtents;
2501+
TODO(converter.getCurrentLocation(),
2502+
"target map clause operand unsupported bound type");
2503+
};
2504+
2505+
auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
2506+
llvm::SmallVector<mlir::Value> clonedBounds;
2507+
for (mlir::Value bound : bounds)
2508+
clonedBounds.emplace_back(cloneBound(bound));
2509+
return clonedBounds;
25072510
};
25082511

25092512
// Bind the symbols to their corresponding block arguments.
@@ -2512,34 +2515,31 @@ static void genBodyOfTargetOp(
25122515
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
25132516
extVal.match(
25142517
[&](const fir::BoxValue &v) {
2515-
converter.bindSymbol(
2516-
*sym,
2517-
fir::BoxValue(arg, extractBoundArgs(v.getLBounds().size()),
2518-
v.getExplicitParameters(), v.getExplicitExtents()));
2518+
converter.bindSymbol(*sym,
2519+
fir::BoxValue(arg, cloneBounds(v.getLBounds()),
2520+
v.getExplicitParameters(),
2521+
v.getExplicitExtents()));
25192522
},
25202523
[&](const fir::MutableBoxValue &v) {
25212524
converter.bindSymbol(
2522-
*sym,
2523-
fir::MutableBoxValue(arg, extractBoundArgs(v.getLBounds().size()),
2524-
v.getMutableProperties()));
2525+
*sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
2526+
v.getMutableProperties()));
25252527
},
25262528
[&](const fir::ArrayBoxValue &v) {
25272529
converter.bindSymbol(
2528-
*sym,
2529-
fir::ArrayBoxValue(arg, extractBoundArgs(v.getExtents().size()),
2530-
extractBoundArgs(v.getLBounds().size()),
2531-
v.getSourceBox()));
2530+
*sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
2531+
cloneBounds(v.getLBounds()),
2532+
v.getSourceBox()));
25322533
},
25332534
[&](const fir::CharArrayBoxValue &v) {
25342535
converter.bindSymbol(
2535-
*sym,
2536-
fir::CharArrayBoxValue(arg, extractBoundArgs(1).front(),
2537-
extractBoundArgs(v.getExtents().size()),
2538-
extractBoundArgs(v.getLBounds().size())));
2536+
*sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
2537+
cloneBounds(v.getExtents()),
2538+
cloneBounds(v.getLBounds())));
25392539
},
25402540
[&](const fir::CharBoxValue &v) {
2541-
converter.bindSymbol(
2542-
*sym, fir::CharBoxValue(arg, extractBoundArgs(1).front()));
2541+
converter.bindSymbol(*sym,
2542+
fir::CharBoxValue(arg, cloneBound(v.getLen())));
25432543
},
25442544
[&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
25452545
[&](const auto &) {
@@ -2549,6 +2549,55 @@ static void genBodyOfTargetOp(
25492549
argIndex++;
25502550
}
25512551

2552+
// Check if cloning the bounds introduced any dependency on the outer region.
2553+
// If so, then either clone them as well if they are MemoryEffectFree, or else
2554+
// copy them to a new temporary and add them to the map and block_argument
2555+
// lists and replace their uses with the new temporary.
2556+
llvm::SetVector<mlir::Value> valuesDefinedAbove;
2557+
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
2558+
while (!valuesDefinedAbove.empty()) {
2559+
for (mlir::Value val : valuesDefinedAbove) {
2560+
mlir::Operation *valOp = val.getDefiningOp();
2561+
if (mlir::isMemoryEffectFree(valOp)) {
2562+
mlir::Operation *clonedOp = valOp->clone();
2563+
regionBlock->push_front(clonedOp);
2564+
val.replaceUsesWithIf(
2565+
clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
2566+
return use.getOwner()->getBlock() == regionBlock;
2567+
});
2568+
} else {
2569+
auto savedIP = firOpBuilder.getInsertionPoint();
2570+
firOpBuilder.setInsertionPointAfter(valOp);
2571+
auto copyVal =
2572+
firOpBuilder.createTemporary(val.getLoc(), val.getType());
2573+
firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
2574+
2575+
llvm::SmallVector<mlir::Value> bounds;
2576+
std::stringstream name;
2577+
firOpBuilder.setInsertionPoint(targetOp);
2578+
mlir::Value mapOp = createMapInfoOp(
2579+
firOpBuilder, copyVal.getLoc(), copyVal, name, bounds,
2580+
static_cast<
2581+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2582+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
2583+
mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
2584+
targetOp.getMapOperandsMutable().append(mapOp);
2585+
mlir::Value clonedValArg =
2586+
region.addArgument(copyVal.getType(), copyVal.getLoc());
2587+
firOpBuilder.setInsertionPointToStart(regionBlock);
2588+
auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(),
2589+
clonedValArg);
2590+
val.replaceUsesWithIf(
2591+
loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
2592+
return use.getOwner()->getBlock() == regionBlock;
2593+
});
2594+
firOpBuilder.setInsertionPoint(regionBlock, savedIP);
2595+
}
2596+
}
2597+
valuesDefinedAbove.clear();
2598+
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
2599+
}
2600+
25522601
// Insert dummy instruction to remember the insertion position. The
25532602
// marker will be deleted since there are not uses.
25542603
// In the HLFIR flow there are hlfir.declares inserted above while
@@ -2671,53 +2720,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
26712720
};
26722721
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
26732722

2674-
// Add the bounds and extents for box values to mapOperands
2675-
auto addMapInfoForBounds = [&](const auto &bounds) {
2676-
for (auto &val : bounds) {
2677-
mapSymLocs.push_back(val.getLoc());
2678-
mapSymTypes.push_back(val.getType());
2679-
2680-
llvm::SmallVector<mlir::Value> bounds;
2681-
std::stringstream name;
2682-
2683-
mlir::Value mapOp = createMapInfoOp(
2684-
converter.getFirOpBuilder(), val.getLoc(), val, name, bounds,
2685-
static_cast<
2686-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2687-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
2688-
mlir::omp::VariableCaptureKind::ByCopy, val.getType(), true);
2689-
mapOperands.push_back(mapOp);
2690-
}
2691-
};
2692-
2693-
for (const Fortran::semantics::Symbol *sym : mapSymbols) {
2694-
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
2695-
extVal.match(
2696-
[&](const fir::BoxValue &v) { addMapInfoForBounds(v.getLBounds()); },
2697-
[&](const fir::MutableBoxValue &v) {
2698-
addMapInfoForBounds(v.getLBounds());
2699-
},
2700-
[&](const fir::ArrayBoxValue &v) {
2701-
addMapInfoForBounds(v.getExtents());
2702-
addMapInfoForBounds(v.getLBounds());
2703-
},
2704-
[&](const fir::CharArrayBoxValue &v) {
2705-
llvm::SmallVector<mlir::Value> len;
2706-
len.push_back(v.getLen());
2707-
addMapInfoForBounds(len);
2708-
addMapInfoForBounds(v.getExtents());
2709-
addMapInfoForBounds(v.getLBounds());
2710-
},
2711-
[&](const fir::CharBoxValue &v) {
2712-
llvm::SmallVector<mlir::Value> len;
2713-
len.push_back(v.getLen());
2714-
addMapInfoForBounds(len);
2715-
},
2716-
[&](const auto &) {
2717-
// Nothing to do for non-box values.
2718-
});
2719-
}
2720-
27212723
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
27222724
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
27232725
nowaitAttr, mapOperands);

flang/test/Lower/OpenMP/FIR/array-bounds.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
!ALL: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
1717
!ALL: %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>, !fir.array<10xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
1818
!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ITER]] : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
19-
!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>, index, index) {
19+
!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
2020

2121
subroutine read_write_section()
2222
integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
@@ -64,7 +64,7 @@ end subroutine assumed_shape_array
6464
!ALL: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
6565
!ALL: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
6666
!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
67-
!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, index) {
67+
!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
6868

6969
subroutine assumed_size_array(arr_read_write)
7070
integer, intent(inout) :: arr_read_write(*)

0 commit comments

Comments
 (0)