Skip to content

Commit a2aba2c

Browse files
committed
[MLIR][OpenMP] Add private clause to omp.parallel
Extends the `omp.parallel` op by adding a `private` clause to model [first]private variables. This uses the `omp.private` op to map privatized variables to their corresponding privatizers. Example `omp.private` op with `private` variable: ``` omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) { ^bb0(%arg1: !llvm.ptr): // ... use %arg1 ... omp.terminator } ``` Whether the variable is private or firstprivate is determined by the attributes of the corresponding `omp.private` op.
1 parent 070fad4 commit a2aba2c

File tree

7 files changed

+199
-11
lines changed

7 files changed

+199
-11
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26402640
? nullptr
26412641
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
26422642
reductionDeclSymbols),
2643-
procBindKindAttr);
2643+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2644+
/*privatizers=*/nullptr);
26442645
}
26452646

26462647
static mlir::omp::SectionOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
270270
Variadic<AnyType>:$allocators_vars,
271271
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
272272
OptionalAttr<SymbolRefArrayAttr>:$reductions,
273-
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
273+
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
274+
Variadic<AnyType>:$private_vars,
275+
OptionalAttr<SymbolRefArrayAttr>:$privatizers);
274276

275277
let regions = (region AnyRegion:$region);
276278

@@ -291,6 +293,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
291293
$allocators_vars, type($allocators_vars)
292294
) `)`
293295
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
296+
| `private` `(`
297+
custom<PrivateVarList>(
298+
$private_vars, type($private_vars), $privatizers
299+
) `)`
294300
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
295301
}];
296302
let hasVerifier = 1;

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
420420
/* allocators_vars = */ llvm::SmallVector<Value>{},
421421
/* reduction_vars = */ llvm::SmallVector<Value>{},
422422
/* reductions = */ ArrayAttr{},
423-
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
423+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
424+
/* private_vars = */ ValueRange(),
425+
/* privatizers = */ nullptr);
424426
{
425427

426428
OpBuilder::InsertionGuard guard(rewriter);

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
491491
ArrayAttr reductionSymbols) {
492492
if (reductionSymbols)
493493
printReductionClause(p, op, region, operands, types, reductionSymbols);
494-
p.printRegion(region, /*printEntryBlockArgs=*/false);
494+
p.printRegion(region, /*printEntryBlockArgs=*/true);
495495
}
496496

497497
/// reduction-entry-list ::= reduction-entry
@@ -1057,14 +1057,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
10571057
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
10581058
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
10591059
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1060-
/*proc_bind_val=*/nullptr);
1060+
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1061+
/*privatizers=*/nullptr);
10611062
state.addAttributes(attributes);
10621063
}
10631064

1065+
static LogicalResult verifyPrivateVarList(ParallelOp &op) {
1066+
auto privateVars = op.getPrivateVars();
1067+
auto privatizers = op.getPrivatizersAttr();
1068+
1069+
if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1070+
return success();
1071+
1072+
auto numPrivateVars = privateVars.size();
1073+
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1074+
1075+
if (numPrivateVars != numPrivatizers)
1076+
return op.emitError() << "inconsistent number of private variables and "
1077+
"privatizer op symbols, private vars: "
1078+
<< numPrivateVars
1079+
<< " vs. privatizer op symbols: " << numPrivatizers;
1080+
1081+
for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
1082+
Type varType = std::get<0>(privateVarInfo).getType();
1083+
SymbolRefAttr privatizerSym =
1084+
std::get<1>(privateVarInfo).cast<SymbolRefAttr>();
1085+
PrivateClauseOp privatizerOp =
1086+
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1087+
privatizerSym);
1088+
1089+
if (privatizerOp == nullptr)
1090+
return op.emitError() << "failed to lookup privatizer op with symbol: '"
1091+
<< privatizerSym << "'";
1092+
1093+
Type privatizerType = privatizerOp.getType();
1094+
1095+
if (varType != privatizerType)
1096+
return op.emitError()
1097+
<< "type mismatch between a "
1098+
<< (privatizerOp.getDataSharingType() ==
1099+
DataSharingClauseType::Private
1100+
? "private"
1101+
: "firstprivate")
1102+
<< " variable and its privatizer op, var type: " << varType
1103+
<< " vs. privatizer op type: " << privatizerType;
1104+
}
1105+
1106+
return success();
1107+
}
1108+
10641109
LogicalResult ParallelOp::verify() {
10651110
if (getAllocateVars().size() != getAllocatorsVars().size())
10661111
return emitError(
10671112
"expected equal sizes for allocate and allocator variables");
1113+
1114+
if (failed(verifyPrivateVarList(*this)))
1115+
return failure();
1116+
10681117
return verifyReductionVarList(*this, getReductions(), getReductionVars());
10691118
}
10701119

@@ -1729,6 +1778,68 @@ LogicalResult PrivateClauseOp::verify() {
17291778
return success();
17301779
}
17311780

1781+
static ParseResult parsePrivateVarList(
1782+
OpAsmParser &parser,
1783+
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
1784+
llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privatizersAttr) {
1785+
SymbolRefAttr privatizerSym;
1786+
OpAsmParser::UnresolvedOperand privateArg;
1787+
OpAsmParser::UnresolvedOperand regionArg;
1788+
Type argType;
1789+
1790+
SmallVector<SymbolRefAttr> privatizersVec;
1791+
1792+
auto parsePrivatizers = [&]() -> ParseResult {
1793+
// @privatizer %var -> %region_arg : type
1794+
if (parser.parseAttribute(privatizerSym) ||
1795+
parser.parseOperand(privateArg) || parser.parseArrow() ||
1796+
parser.parseOperand(regionArg) || parser.parseColon() ||
1797+
parser.parseType(argType)) {
1798+
return failure();
1799+
}
1800+
1801+
privatizersVec.push_back(privatizerSym);
1802+
privateVarsOperands.push_back(privateArg);
1803+
privateVarsTypes.push_back(argType);
1804+
1805+
return success();
1806+
};
1807+
1808+
if (parser.parseCommaSeparatedList(parsePrivatizers))
1809+
return failure();
1810+
1811+
SmallVector<Attribute> privatizers(privatizersVec.begin(),
1812+
privatizersVec.end());
1813+
privatizersAttr = ArrayAttr::get(parser.getContext(), privatizers);
1814+
1815+
return success();
1816+
}
1817+
1818+
static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
1819+
OperandRange privateVars,
1820+
TypeRange privateVarTypes,
1821+
std::optional<ArrayAttr> privatizersAttr) {
1822+
unsigned argIndex = 0;
1823+
assert(privatizersAttr);
1824+
1825+
for (const auto &priateVarArgInfo :
1826+
llvm::zip(*privatizersAttr, privateVars, op->getRegion(0).getArguments(),
1827+
privateVarTypes)) {
1828+
assert(privatizersAttr);
1829+
printer << std::get<0>(priateVarArgInfo) << " "
1830+
<< std::get<1>(priateVarArgInfo) << " -> ";
1831+
1832+
std::get<2>(priateVarArgInfo)
1833+
.printAsOperand(printer.getStream(), OpPrintingFlags());
1834+
1835+
printer << " : " << std::get<3>(priateVarArgInfo);
1836+
1837+
++argIndex;
1838+
if (argIndex < privateVars.size())
1839+
printer << ", ";
1840+
}
1841+
}
1842+
17321843
#define GET_ATTRDEF_CLASSES
17331844
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
17341845

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,3 +1801,62 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
18011801
^bb0(%arg0: f32):
18021802
omp.yield(%arg0 : f32)
18031803
}
1804+
1805+
// -----
1806+
1807+
func.func @private_type_mismatch(%arg0: index) {
1808+
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1809+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1810+
^bb0(%arg2: !llvm.ptr):
1811+
omp.terminator
1812+
}
1813+
1814+
return
1815+
}
1816+
1817+
omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
1818+
^bb0(%arg0: !llvm.ptr):
1819+
omp.yield(%arg0 : !llvm.ptr)
1820+
}
1821+
1822+
// -----
1823+
1824+
func.func @firstprivate_type_mismatch(%arg0: index) {
1825+
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
1826+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1827+
^bb0(%arg2: !llvm.ptr):
1828+
omp.terminator
1829+
}
1830+
1831+
return
1832+
}
1833+
1834+
omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
1835+
^bb0(%arg0: !llvm.ptr):
1836+
omp.yield(%arg0 : !llvm.ptr)
1837+
} copy {
1838+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1839+
omp.yield(%arg0 : !llvm.ptr)
1840+
}
1841+
1842+
// -----
1843+
1844+
func.func @undefined_privatizer(%arg0: index) {
1845+
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
1846+
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
1847+
^bb0(%arg2: !llvm.ptr):
1848+
omp.terminator
1849+
}
1850+
1851+
return
1852+
}
1853+
1854+
// -----
1855+
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
1856+
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
1857+
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
1858+
^bb0(%arg2: !llvm.ptr):
1859+
omp.terminator
1860+
}) : (!llvm.ptr) -> ()
1861+
return
1862+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
5959
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6060
"omp.parallel"(%num_threads, %data_var, %data_var) ({
6161
omp.terminator
62-
}) {operandSegmentSizes = array<i32: 0,1,1,1,0>} : (i32, memref<i32>, memref<i32>) -> ()
62+
}) {operandSegmentSizes = array<i32: 0,1,1,1,0,0>} : (i32, memref<i32>, memref<i32>) -> ()
6363

6464
// CHECK: omp.barrier
6565
omp.barrier
@@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
6868
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6969
"omp.parallel"(%if_cond, %data_var, %data_var) ({
7070
omp.terminator
71-
}) {operandSegmentSizes = array<i32: 1,0,1,1,0>} : (i1, memref<i32>, memref<i32>) -> ()
71+
}) {operandSegmentSizes = array<i32: 1,0,1,1,0,0>} : (i1, memref<i32>, memref<i32>) -> ()
7272

7373
// test without allocate
7474
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
7575
"omp.parallel"(%if_cond, %num_threads) ({
7676
omp.terminator
77-
}) {operandSegmentSizes = array<i32: 1,1,0,0,0>} : (i1, i32) -> ()
77+
}) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (i1, i32) -> ()
7878

7979
omp.terminator
80-
}) {operandSegmentSizes = array<i32: 1,1,1,1,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
80+
}) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
8181

8282
// test with multiple parameters for single variadic argument
8383
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
8484
"omp.parallel" (%data_var, %data_var) ({
8585
omp.terminator
86-
}) {operandSegmentSizes = array<i32: 0,0,1,1,0>} : (memref<i32>, memref<i32>) -> ()
86+
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
8787

8888
return
8989
}

mlir/test/Dialect/OpenMP/roundtrip.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
22

3+
// CHECK-LABEL: parallel_op_privatizers
4+
func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
5+
// CHECK: omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr)
6+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
7+
^bb0(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
8+
omp.terminator
9+
}
10+
return
11+
}
12+
313
// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
414
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
515
// CHECK: ^bb0(%arg0: {{.*}}):
@@ -18,4 +28,3 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
1828
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
1929
omp.yield(%arg0 : !llvm.ptr)
2030
}
21-

0 commit comments

Comments
 (0)