Skip to content

Commit acf8cdf

Browse files
committed
[MLIR][OpenMP] Add private clause to omp.parallel
Extends the `omp.parallel` op by adding a `private` clause to model [first]private variables. This uses the `omp.private` op to map privatized variables to their corresponding privatizers. Example `omp.private` op with `private` variable: ``` omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) { // ... use %arg1 ... omp.terminator } ``` Whether the variable is private or firstprivate is determined by the attributes of the corresponding `omp.private` op.
1 parent 1ecbab5 commit acf8cdf

File tree

7 files changed

+255
-67
lines changed

7 files changed

+255
-67
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26402640
? nullptr
26412641
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
26422642
reductionDeclSymbols),
2643-
procBindKindAttr);
2643+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2644+
/*privatizers=*/nullptr);
26442645
}
26452646

26462647
static mlir::omp::SectionOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
276276
Variadic<AnyType>:$allocators_vars,
277277
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
278278
OptionalAttr<SymbolRefArrayAttr>:$reductions,
279-
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
279+
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
280+
Variadic<AnyType>:$private_vars,
281+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
280282

281283
let regions = (region AnyRegion:$region);
282284

@@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
297299
$allocators_vars, type($allocators_vars)
298300
) `)`
299301
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
300-
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
302+
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars),
303+
$reductions, $private_vars, type($private_vars),
304+
$privatizers) attr-dict
301305
}];
302306
let hasVerifier = 1;
303307
}

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
450450
/* allocators_vars = */ llvm::SmallVector<Value>{},
451451
/* reduction_vars = */ llvm::SmallVector<Value>{},
452452
/* reductions = */ ArrayAttr{},
453-
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
453+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
454+
/* private_vars = */ ValueRange(),
455+
/* privatizers = */ nullptr);
454456
{
455457

456458
OpBuilder::InsertionGuard guard(rewriter);

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 123 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
430430
// Parser, printer and verifier for ReductionVarList
431431
//===----------------------------------------------------------------------===//
432432

433-
ParseResult
434-
parseReductionClause(OpAsmParser &parser, Region &region,
435-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
437-
SmallVectorImpl<OpAsmParser::Argument> &privates) {
438-
if (failed(parser.parseOptionalKeyword("reduction")))
439-
return failure();
440-
433+
ParseResult parseClauseWithRegionArgs(
434+
OpAsmParser &parser, Region &region,
435+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436+
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437+
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
441438
SmallVector<SymbolRefAttr> reductionVec;
439+
unsigned regionArgOffset = regionPrivateArgs.size();
442440

443441
if (failed(
444442
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
445443
if (parser.parseAttribute(reductionVec.emplace_back()) ||
446444
parser.parseOperand(operands.emplace_back()) ||
447445
parser.parseArrow() ||
448-
parser.parseArgument(privates.emplace_back()) ||
446+
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
449447
parser.parseColonType(types.emplace_back()))
450448
return failure();
451449
return success();
452450
})))
453451
return failure();
454452

455-
for (auto [prv, type] : llvm::zip_equal(privates, types)) {
453+
auto *argsBegin = regionPrivateArgs.begin();
454+
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
455+
argsBegin + regionArgOffset + types.size());
456+
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
456457
prv.type = type;
457458
}
458459
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
459-
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
460+
symbols = ArrayAttr::get(parser.getContext(), reductions);
460461
return success();
461462
}
462463

463-
static void printReductionClause(OpAsmPrinter &p, Operation *op,
464-
ValueRange reductionArgs, ValueRange operands,
465-
TypeRange types, ArrayAttr reductionSymbols) {
466-
p << "reduction(";
464+
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
465+
ValueRange argsSubrange,
466+
StringRef clauseName, ValueRange operands,
467+
TypeRange types, ArrayAttr symbols) {
468+
p << clauseName << "(";
467469
llvm::interleaveComma(
468-
llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
469-
[&p](auto t) {
470+
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
470471
auto [sym, op, arg, type] = t;
471472
p << sym << " " << op << " -> " << arg << " : " << type;
472473
});
473474
p << ") ";
474475
}
475476

476-
static ParseResult
477-
parseParallelRegion(OpAsmParser &parser, Region &region,
478-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479-
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
477+
static ParseResult parseParallelRegion(
478+
OpAsmParser &parser, Region &region,
479+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
480+
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
481+
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
482+
llvm::SmallVectorImpl<Type> &privateVarsTypes,
483+
ArrayAttr &privatizerSymbols) {
484+
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
480485

481-
llvm::SmallVector<OpAsmParser::Argument> privates;
482-
if (succeeded(parseReductionClause(parser, region, operands, types,
483-
reductionSymbols, privates)))
484-
return parser.parseRegion(region, privates);
486+
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
487+
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
488+
reductionVarTypes, reductionSymbols,
489+
regionPrivateArgs)))
490+
return failure();
491+
}
485492

486-
return parser.parseRegion(region);
493+
if (succeeded(parser.parseOptionalKeyword("private"))) {
494+
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
495+
privateVarsTypes, privatizerSymbols,
496+
regionPrivateArgs)))
497+
return failure();
498+
}
499+
500+
return parser.parseRegion(region, regionPrivateArgs);
487501
}
488502

489503
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
490-
ValueRange operands, TypeRange types,
491-
ArrayAttr reductionSymbols) {
492-
if (reductionSymbols)
493-
printReductionClause(p, op, region.front().getArguments(), operands, types,
494-
reductionSymbols);
504+
ValueRange reductionVarOperands,
505+
TypeRange reductionVarTypes,
506+
ArrayAttr reductionSymbols,
507+
ValueRange privateVarOperands,
508+
TypeRange privateVarTypes,
509+
ArrayAttr privatizerSymbols) {
510+
if (reductionSymbols) {
511+
auto *argsBegin = region.front().getArguments().begin();
512+
MutableArrayRef argsSubrange(argsBegin,
513+
argsBegin + reductionVarTypes.size());
514+
printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
515+
reductionVarOperands, reductionVarTypes,
516+
reductionSymbols);
517+
}
518+
519+
if (privatizerSymbols) {
520+
auto *argsBegin = region.front().getArguments().begin();
521+
MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
522+
argsBegin + reductionVarOperands.size() +
523+
privateVarTypes.size());
524+
printClauseWithRegionArgs(p, op, argsSubrange, "private",
525+
privateVarOperands, privateVarTypes,
526+
privatizerSymbols);
527+
}
528+
495529
p.printRegion(region, /*printEntryBlockArgs=*/false);
496530
}
497531

@@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
11741208
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
11751209
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
11761210
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1177-
/*proc_bind_val=*/nullptr);
1211+
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1212+
/*privatizers=*/nullptr);
11781213
state.addAttributes(attributes);
11791214
}
11801215

1216+
template <typename OpType>
1217+
static LogicalResult verifyPrivateVarList(OpType &op) {
1218+
auto privateVars = op.getPrivateVars();
1219+
auto privatizers = op.getPrivatizersAttr();
1220+
1221+
if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1222+
return success();
1223+
1224+
auto numPrivateVars = privateVars.size();
1225+
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1226+
1227+
if (numPrivateVars != numPrivatizers)
1228+
return op.emitError() << "inconsistent number of private variables and "
1229+
"privatizer op symbols, private vars: "
1230+
<< numPrivateVars
1231+
<< " vs. privatizer op symbols: " << numPrivatizers;
1232+
1233+
for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1234+
Type varType = std::get<0>(privateVarInfo).getType();
1235+
SymbolRefAttr privatizerSym =
1236+
std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
1237+
PrivateClauseOp privatizerOp =
1238+
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1239+
privatizerSym);
1240+
1241+
if (privatizerOp == nullptr)
1242+
return op.emitError() << "failed to lookup privatizer op with symbol: '"
1243+
<< privatizerSym << "'";
1244+
1245+
Type privatizerType = privatizerOp.getType();
1246+
1247+
if (varType != privatizerType)
1248+
return op.emitError()
1249+
<< "type mismatch between a "
1250+
<< (privatizerOp.getDataSharingType() ==
1251+
DataSharingClauseType::Private
1252+
? "private"
1253+
: "firstprivate")
1254+
<< " variable and its privatizer op, var type: " << varType
1255+
<< " vs. privatizer op type: " << privatizerType;
1256+
}
1257+
1258+
return success();
1259+
}
1260+
11811261
LogicalResult ParallelOp::verify() {
11821262
if (getAllocateVars().size() != getAllocatorsVars().size())
11831263
return emitError(
11841264
"expected equal sizes for allocate and allocator variables");
1265+
1266+
if (failed(verifyPrivateVarList(*this)))
1267+
return failure();
1268+
11851269
return verifyReductionVarList(*this, getReductions(), getReductionVars());
11861270
}
11871271

@@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,
12791363

12801364
// Parse an optional reduction clause
12811365
llvm::SmallVector<OpAsmParser::Argument> privates;
1282-
bool hasReduction = succeeded(
1283-
parseReductionClause(parser, region, reductionOperands, reductionTypes,
1284-
reductionSymbols, privates));
1366+
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
1367+
succeeded(parseClauseWithRegionArgs(
1368+
parser, region, reductionOperands, reductionTypes,
1369+
reductionSymbols, privates));
12851370

12861371
if (parser.parseKeyword("for"))
12871372
return failure();
@@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
13281413
if (reductionSymbols) {
13291414
auto reductionArgs =
13301415
region.front().getArguments().drop_front(loopVarTypes.size());
1331-
printReductionClause(p, op, reductionArgs, reductionOperands,
1332-
reductionTypes, reductionSymbols);
1416+
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
1417+
reductionOperands, reductionTypes,
1418+
reductionSymbols);
13331419
}
13341420

13351421
p << " for ";

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
18651865
^bb0(%arg0: f32):
18661866
omp.yield(%arg0 : f32)
18671867
}
1868+
1869+
// -----
1870+
1871+
func.func @private_type_mismatch(%arg0: index) {
1872+
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1873+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1874+
omp.terminator
1875+
}
1876+
1877+
return
1878+
}
1879+
1880+
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
1881+
^bb0(%arg0: !llvm.ptr):
1882+
omp.yield(%arg0 : !llvm.ptr)
1883+
}
1884+
1885+
// -----
1886+
1887+
func.func @firstprivate_type_mismatch(%arg0: index) {
1888+
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1889+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1890+
omp.terminator
1891+
}
1892+
1893+
return
1894+
}
1895+
1896+
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
1897+
^bb0(%arg0: !llvm.ptr):
1898+
omp.yield(%arg0 : !llvm.ptr)
1899+
} copy {
1900+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1901+
omp.yield(%arg0 : !llvm.ptr)
1902+
}
1903+
1904+
// -----
1905+
1906+
func.func @undefined_privatizer(%arg0: index) {
1907+
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
1908+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1909+
omp.terminator
1910+
}
1911+
1912+
return
1913+
}
1914+
1915+
// -----
1916+
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
1917+
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
1918+
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
1919+
^bb0(%arg2: !llvm.ptr):
1920+
omp.terminator
1921+
}) : (!llvm.ptr) -> ()
1922+
return
1923+
}

0 commit comments

Comments
 (0)