Skip to content

Commit 9d7a874

Browse files
committed
[MLIR][OpenMP] Add private clause to omp.parallel
Extends the `omp.parallel` op by adding a `private` clause to model [first]private variables. This uses the `omp.private` op to map privatized variables to their corresponding privatizers. Example `omp.private` op with `private` variable: ``` omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) { // ... use %arg1 ... omp.terminator } ``` Whether the variable is private or firstprivate is determined by the attributes of the corresponding `omp.private` op.
1 parent bd2f7bb commit 9d7a874

File tree

8 files changed

+275
-76
lines changed

8 files changed

+275
-76
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26402640
? nullptr
26412641
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
26422642
reductionDeclSymbols),
2643-
procBindKindAttr);
2643+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2644+
/*privatizers=*/nullptr);
26442645
}
26452646

26462647
static mlir::omp::SectionOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
270270
Variadic<AnyType>:$allocators_vars,
271271
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
272272
OptionalAttr<SymbolRefArrayAttr>:$reductions,
273-
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
273+
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
274+
Variadic<AnyType>:$private_vars,
275+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
274276

275277
let regions = (region AnyRegion:$region);
276278

@@ -291,7 +293,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
291293
$allocators_vars, type($allocators_vars)
292294
) `)`
293295
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
294-
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
296+
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions, $private_vars, type($private_vars), $privatizers) attr-dict
295297
}];
296298
let hasVerifier = 1;
297299
}

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
450450
/* allocators_vars = */ llvm::SmallVector<Value>{},
451451
/* reduction_vars = */ llvm::SmallVector<Value>{},
452452
/* reductions = */ ArrayAttr{},
453-
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
453+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
454+
/* private_vars = */ ValueRange(),
455+
/* privatizers = */ nullptr);
454456
{
455457

456458
OpBuilder::InsertionGuard guard(rewriter);

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 131 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
430430
// Parser, printer and verifier for ReductionVarList
431431
//===----------------------------------------------------------------------===//
432432

433-
ParseResult
434-
parseReductionClause(OpAsmParser &parser, Region &region,
435-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
437-
SmallVectorImpl<OpAsmParser::Argument> &privates) {
438-
if (failed(parser.parseOptionalKeyword("reduction")))
439-
return failure();
440-
433+
ParseResult parseClauseWithRegionArgs(
434+
OpAsmParser &parser, Region &region,
435+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436+
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437+
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
441438
SmallVector<SymbolRefAttr> reductionVec;
439+
unsigned regionArgOffset = regionPrivateArgs.size();
442440

443441
if (failed(
444442
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
445443
if (parser.parseAttribute(reductionVec.emplace_back()) ||
446444
parser.parseOperand(operands.emplace_back()) ||
447445
parser.parseArrow() ||
448-
parser.parseArgument(privates.emplace_back()) ||
446+
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
449447
parser.parseColonType(types.emplace_back()))
450448
return failure();
451449
return success();
452450
})))
453451
return failure();
454452

455-
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
453+
auto *argsBegin = regionPrivateArgs.begin();
454+
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
455+
argsBegin + regionArgOffset + types.size());
456+
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
456457
prv.type = type;
457458
}
458459
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
459-
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
460+
symbols = ArrayAttr::get(parser.getContext(), reductions);
460461
return success();
461462
}
462463

463-
static void printReductionClause(OpAsmPrinter &p, Operation *op,
464-
ValueRange reductionArgs, ValueRange operands,
465-
TypeRange types, ArrayAttr reductionSymbols) {
466-
p << "reduction(";
464+
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
465+
ValueRange argsSubrange,
466+
StringRef clauseName, ValueRange operands,
467+
TypeRange types, ArrayAttr symbols) {
468+
p << clauseName << "(";
467469
llvm::interleaveComma(
468-
llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
469-
[&p](auto t) {
470+
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
470471
auto [sym, op, arg, type] = t;
471472
p << sym << " " << op << " -> " << arg << " : " << type;
472473
});
473474
p << ") ";
474475
}
475476

476-
static ParseResult
477-
parseParallelRegion(OpAsmParser &parser, Region &region,
478-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
477+
static ParseResult parseParallelRegion(
478+
OpAsmParser &parser, Region &region,
479+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
480+
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
481+
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
482+
llvm::SmallVectorImpl<Type> &privateVarsTypes,
483+
ArrayAttr &privatizerSymbols) {
484+
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
480485

481-
llvm::SmallVector<OpAsmParser::Argument> privates;
482-
if (succeeded(parseReductionClause(parser, region, operands, types,
483-
reductionSymbols, privates)))
484-
return parser.parseRegion(region, privates);
486+
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
487+
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
488+
reductionVarTypes, reductionSymbols,
489+
regionPrivateArgs)))
490+
return failure();
491+
}
485492

486-
return parser.parseRegion(region);
493+
if (succeeded(parser.parseOptionalKeyword("private"))) {
494+
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
495+
privateVarsTypes, privatizerSymbols,
496+
regionPrivateArgs)))
497+
return failure();
498+
}
499+
500+
return parser.parseRegion(region, regionPrivateArgs);
487501
}
488502

489503
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
490-
ValueRange operands, TypeRange types,
491-
ArrayAttr reductionSymbols) {
492-
if (reductionSymbols)
493-
printReductionClause(p, op, region.front().getArguments(), operands, types,
494-
reductionSymbols);
504+
ValueRange reductionVarOperands,
505+
TypeRange reductionVarTypes,
506+
ArrayAttr reductionSymbols,
507+
ValueRange privateVarOperands,
508+
TypeRange privateVarTypes,
509+
ArrayAttr privatizerSymbols) {
510+
if (reductionSymbols) {
511+
auto *argsBegin = region.front().getArguments().begin();
512+
MutableArrayRef argsSubrange(argsBegin,
513+
argsBegin + reductionVarTypes.size());
514+
printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
515+
reductionVarOperands, reductionVarTypes,
516+
reductionSymbols);
517+
}
518+
519+
if (privatizerSymbols) {
520+
auto *argsBegin = region.front().getArguments().begin();
521+
MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
522+
argsBegin + reductionVarOperands.size() +
523+
privateVarTypes.size());
524+
printClauseWithRegionArgs(p, op, argsSubrange, "private",
525+
privateVarOperands, privateVarTypes,
526+
privatizerSymbols);
527+
}
528+
495529
p.printRegion(region, /*printEntryBlockArgs=*/false);
496530
}
497531

@@ -1008,9 +1042,8 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
10081042
}
10091043

10101044
if (always || close || implicit) {
1011-
return emitError(
1012-
op->getLoc(),
1013-
"present, mapper and iterator map type modifiers are permitted");
1045+
return emitError(op->getLoc(), "present, mapper and iterator map "
1046+
"type modifiers are permitted");
10141047
}
10151048

10161049
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
@@ -1070,14 +1103,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
10701103
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
10711104
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
10721105
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1073-
/*proc_bind_val=*/nullptr);
1106+
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1107+
/*privatizers=*/nullptr);
10741108
state.addAttributes(attributes);
10751109
}
10761110

1111+
static LogicalResult verifyPrivateVarList(ParallelOp &op) {
1112+
auto privateVars = op.getPrivateVars();
1113+
auto privatizers = op.getPrivatizersAttr();
1114+
1115+
if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1116+
return success();
1117+
1118+
auto numPrivateVars = privateVars.size();
1119+
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1120+
1121+
if (numPrivateVars != numPrivatizers)
1122+
return op.emitError() << "inconsistent number of private variables and "
1123+
"privatizer op symbols, private vars: "
1124+
<< numPrivateVars
1125+
<< " vs. privatizer op symbols: " << numPrivatizers;
1126+
1127+
for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
1128+
Type varType = std::get<0>(privateVarInfo).getType();
1129+
SymbolRefAttr privatizerSym =
1130+
std::get<1>(privateVarInfo).cast<SymbolRefAttr>();
1131+
PrivateClauseOp privatizerOp =
1132+
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1133+
privatizerSym);
1134+
1135+
if (privatizerOp == nullptr)
1136+
return op.emitError() << "failed to lookup privatizer op with symbol: '"
1137+
<< privatizerSym << "'";
1138+
1139+
Type privatizerType = privatizerOp.getType();
1140+
1141+
if (varType != privatizerType)
1142+
return op.emitError()
1143+
<< "type mismatch between a "
1144+
<< (privatizerOp.getDataSharingType() ==
1145+
DataSharingClauseType::Private
1146+
? "private"
1147+
: "firstprivate")
1148+
<< " variable and its privatizer op, var type: " << varType
1149+
<< " vs. privatizer op type: " << privatizerType;
1150+
}
1151+
1152+
return success();
1153+
}
1154+
10771155
LogicalResult ParallelOp::verify() {
10781156
if (getAllocateVars().size() != getAllocatorsVars().size())
10791157
return emitError(
10801158
"expected equal sizes for allocate and allocator variables");
1159+
1160+
if (failed(verifyPrivateVarList(*this)))
1161+
return failure();
1162+
10811163
return verifyReductionVarList(*this, getReductions(), getReductionVars());
10821164
}
10831165

@@ -1111,8 +1193,8 @@ LogicalResult TeamsOp::verify() {
11111193
return emitError("expected num_teams upper bound to be defined if the "
11121194
"lower bound is defined");
11131195
if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1114-
return emitError(
1115-
"expected num_teams upper bound and lower bound to be the same type");
1196+
return emitError("expected num_teams upper bound and lower bound to be "
1197+
"the same type");
11161198
}
11171199

11181200
// Check for allocate clause restrictions
@@ -1174,9 +1256,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,
11741256

11751257
// Parse an optional reduction clause
11761258
llvm::SmallVector<OpAsmParser::Argument> privates;
1177-
bool hasReduction = succeeded(
1178-
parseReductionClause(parser, region, reductionOperands, reductionTypes,
1179-
reductionSymbols, privates));
1259+
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
1260+
succeeded(parseClauseWithRegionArgs(
1261+
parser, region, reductionOperands, reductionTypes,
1262+
reductionSymbols, privates));
11801263

11811264
if (parser.parseKeyword("for"))
11821265
return failure();
@@ -1223,8 +1306,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
12231306
if (reductionSymbols) {
12241307
auto reductionArgs =
12251308
region.front().getArguments().drop_front(loopVarTypes.size());
1226-
printReductionClause(p, op, reductionArgs, reductionOperands,
1227-
reductionTypes, reductionSymbols);
1309+
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
1310+
reductionOperands, reductionTypes,
1311+
reductionSymbols);
12281312
}
12291313

12301314
p << " for ";
@@ -1464,9 +1548,9 @@ LogicalResult TaskLoopOp::verify() {
14641548
}
14651549

14661550
if (getGrainSize() && getNumTasks()) {
1467-
return emitError(
1468-
"the grainsize clause and num_tasks clause are mutually exclusive and "
1469-
"may not appear on the same taskloop directive");
1551+
return emitError("the grainsize clause and num_tasks clause are mutually "
1552+
"exclusive and "
1553+
"may not appear on the same taskloop directive");
14701554
}
14711555
return success();
14721556
}
@@ -1535,7 +1619,8 @@ LogicalResult OrderedOp::verify() {
15351619
}
15361620

15371621
LogicalResult OrderedRegionOp::verify() {
1538-
// TODO: The code generation for ordered simd directive is not supported yet.
1622+
// TODO: The code generation for ordered simd directive is not supported
1623+
// yet.
15391624
if (getSimd())
15401625
return failure();
15411626

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,3 +1809,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
18091809
^bb0(%arg0: f32):
18101810
omp.yield(%arg0 : f32)
18111811
}
1812+
1813+
// -----
1814+
1815+
func.func @private_type_mismatch(%arg0: index) {
1816+
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1817+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1818+
omp.terminator
1819+
}
1820+
1821+
return
1822+
}
1823+
1824+
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
1825+
^bb0(%arg0: !llvm.ptr):
1826+
omp.yield(%arg0 : !llvm.ptr)
1827+
}
1828+
1829+
// -----
1830+
1831+
func.func @firstprivate_type_mismatch(%arg0: index) {
1832+
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1833+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1834+
omp.terminator
1835+
}
1836+
1837+
return
1838+
}
1839+
1840+
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
1841+
^bb0(%arg0: !llvm.ptr):
1842+
omp.yield(%arg0 : !llvm.ptr)
1843+
} copy {
1844+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1845+
omp.yield(%arg0 : !llvm.ptr)
1846+
}
1847+
1848+
// -----
1849+
1850+
func.func @undefined_privatizer(%arg0: index) {
1851+
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
1852+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1853+
omp.terminator
1854+
}
1855+
1856+
return
1857+
}
1858+
1859+
// -----
1860+
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
1861+
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
1862+
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
1863+
^bb0(%arg2: !llvm.ptr):
1864+
omp.terminator
1865+
}) : (!llvm.ptr) -> ()
1866+
return
1867+
}

0 commit comments

Comments
 (0)