Skip to content

Commit ee17955

Browse files
authored
[MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (#120994)
This patch adds the mapper field to the omp.map.info op. Depends on #117046.
1 parent 74cb1f9 commit ee17955

File tree

8 files changed

+32
-6
lines changed

8 files changed

+32
-6
lines changed

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
125125
llvm::ArrayRef<mlir::Value> members,
126126
mlir::ArrayAttr membersIndex, uint64_t mapType,
127127
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
128-
bool partialMap) {
128+
bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
129129
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
130130
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
131131
retTy = baseAddr.getType();
@@ -144,6 +144,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
144144
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
145145
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
146146
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
147+
mapperId,
147148
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
148149
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
149150
return op;

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
116116
llvm::ArrayRef<mlir::Value> members,
117117
mlir::ArrayAttr membersIndex, uint64_t mapType,
118118
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
119-
bool partialMap = false);
119+
bool partialMap = false,
120+
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());
120121

121122
void insertChildMapInfoIntoParent(
122123
Fortran::lower::AbstractConverter &converter,

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ class MapInfoFinalizationPass
184184
/*members=*/mlir::SmallVector<mlir::Value>{},
185185
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
186186
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
187+
/*mapperId*/ mlir::FlatSymbolRefAttr(),
187188
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
188189
mlir::omp::VariableCaptureKind::ByRef),
189190
/*name=*/builder.getStringAttr(""),
@@ -329,7 +330,8 @@ class MapInfoFinalizationPass
329330
builder.getIntegerAttr(
330331
builder.getIntegerType(64, false),
331332
getDescriptorMapType(op.getMapType().value_or(0), target)),
332-
op.getMapCaptureTypeAttr(), op.getNameAttr(),
333+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
334+
op.getNameAttr(),
333335
/*partial_map=*/builder.getBoolAttr(false));
334336
op.replaceAllUsesWith(newDescParentMapOp.getResult());
335337
op->erase();
@@ -623,6 +625,7 @@ class MapInfoFinalizationPass
623625
/*members=*/mlir::ValueRange{},
624626
/*members_index=*/mlir::ArrayAttr{},
625627
/*bounds=*/bounds, op.getMapTypeAttr(),
628+
/*mapperId*/ mlir::FlatSymbolRefAttr(),
626629
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
627630
mlir::omp::VariableCaptureKind::ByRef),
628631
builder.getStringAttr(op.getNameAttr().strref() + "." +

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass
9191
/*bounds=*/ValueRange{},
9292
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
9393
mapTypeTo),
94+
/*mapperId*/ mlir::FlatSymbolRefAttr(),
9495
builder.getAttr<omp::VariableCaptureKindAttr>(
9596
omp::VariableCaptureKind::ByRef),
9697
StringAttr(), builder.getBoolAttr(false));

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10231023
OptionalAttr<IndexListArrayAttr>:$members_index,
10241024
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
10251025
OptionalAttr<UI64Attr>:$map_type,
1026+
OptionalAttr<FlatSymbolRefAttr>:$mapper_id,
10261027
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
10271028
OptionalAttr<StrAttr>:$name,
10281029
DefaultValuedAttr<BoolAttr, "false">:$partial_map);
@@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10761077
- 'map_type': OpenMP map type for this map capture, for example: from, to and
10771078
always. It's a bitfield composed of the OpenMP runtime flags stored in
10781079
OpenMPOffloadMappingFlags.
1080+
- 'mapper_id': OpenMP mapper map type modifier for this map capture. It's used to
1081+
specify a user defined mapper to be used for mapping.
10791082
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
10801083
this can affect how the variable is lowered.
10811084
- `name`: Holds the name of variable as specified in user clause (including bounds).
@@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10871090
`var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
10881091
oilist(
10891092
`var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
1093+
| `mapper` `(` $mapper_id `)`
10901094
| `map_clauses` `(` custom<MapClause>($map_type) `)`
10911095
| `capture` `(` custom<CaptureType>($map_capture_type) `)`
10921096
| `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1632,7 +1632,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
16321632

16331633
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
16341634
}
1635-
} else {
1635+
1636+
if (mapInfoOp.getMapperId() &&
1637+
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1638+
mapInfoOp, mapInfoOp.getMapperIdAttr())) {
1639+
return emitError(op->getLoc(), "invalid mapper id");
1640+
}
1641+
} else if (!isa<DeclareMapperInfoOp>(op)) {
16361642
emitError(op->getLoc(), "map argument is not a map entry operation");
16371643
}
16381644
}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2849,3 +2849,13 @@ func.func @missing_workshare(%idx : index) {
28492849
^bb0(%arg0: !llvm.ptr):
28502850
omp.terminator
28512851
}
2852+
2853+
// -----
2854+
llvm.func @invalid_mapper(%0 : !llvm.ptr) {
2855+
%1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2856+
// expected-error @below {{invalid mapper id}}
2857+
omp.target_data map_entries(%1 : !llvm.ptr) {
2858+
omp.terminator
2859+
}
2860+
llvm.return
2861+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
25462546
// CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64
25472547
// CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64
25482548
// CHECK: %[[BOUNDS1:.*]] = omp.map.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64)
2549-
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
2549+
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
25502550
%6 = llvm.mlir.constant(9 : index) : i64
25512551
%7 = llvm.mlir.constant(1 : index) : i64
25522552
%8 = llvm.mlir.constant(2 : index) : i64
25532553
%9 = llvm.mlir.constant(2 : index) : i64
25542554
%10 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64)
2555-
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
2555+
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
25562556

25572557
// CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
25582558
omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {

0 commit comments

Comments
 (0)