Skip to content

[MLIR][OpenMP] Add private clause to omp.parallel #81452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
/*privatizers=*/nullptr);
}

static mlir::omp::SectionOp
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$privatizers);

let regions = (region AnyRegion:$region);

Expand All @@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars),
$reductions, $private_vars, type($private_vars),
$privatizers) attr-dict
}];
let hasVerifier = 1;
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reductions = */ ArrayAttr{},
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
/* private_vars = */ ValueRange(),
/* privatizers = */ nullptr);
{

OpBuilder::InsertionGuard guard(rewriter);
Expand Down
160 changes: 123 additions & 37 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//

ParseResult
parseReductionClause(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
SmallVectorImpl<OpAsmParser::Argument> &privates) {
if (failed(parser.parseOptionalKeyword("reduction")))
return failure();

ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVector<SymbolRefAttr> reductionVec;
unsigned regionArgOffset = regionPrivateArgs.size();

if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseArrow() ||
parser.parseArgument(privates.emplace_back()) ||
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();

for (auto [prv, type] : llvm::zip_equal(privates, types)) {
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
argsBegin + regionArgOffset + types.size());
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
prv.type = type;
}
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
symbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
}

static void printReductionClause(OpAsmPrinter &p, Operation *op,
ValueRange reductionArgs, ValueRange operands,
TypeRange types, ArrayAttr reductionSymbols) {
p << "reduction(";
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
TypeRange types, ArrayAttr symbols) {
p << clauseName << "(";
llvm::interleaveComma(
llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
[&p](auto t) {
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
p << ") ";
}

static ParseResult
parseParallelRegion(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
static ParseResult parseParallelRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
llvm::SmallVectorImpl<Type> &privateVarsTypes,
ArrayAttr &privatizerSymbols) {
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parseReductionClause(parser, region, operands, types,
reductionSymbols, privates)))
return parser.parseRegion(region, privates);
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
reductionVarTypes, reductionSymbols,
regionPrivateArgs)))
return failure();
}

return parser.parseRegion(region);
if (succeeded(parser.parseOptionalKeyword("private"))) {
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
privateVarsTypes, privatizerSymbols,
regionPrivateArgs)))
return failure();
}

return parser.parseRegion(region, regionPrivateArgs);
}

static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange operands, TypeRange types,
ArrayAttr reductionSymbols) {
if (reductionSymbols)
printReductionClause(p, op, region.front().getArguments(), operands, types,
reductionSymbols);
ValueRange reductionVarOperands,
TypeRange reductionVarTypes,
ArrayAttr reductionSymbols,
ValueRange privateVarOperands,
TypeRange privateVarTypes,
ArrayAttr privatizerSymbols) {
if (reductionSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin,
argsBegin + reductionVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
reductionVarOperands, reductionVarTypes,
reductionSymbols);
}

if (privatizerSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
argsBegin + reductionVarOperands.size() +
privateVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes,
privatizerSymbols);
}

p.printRegion(region, /*printEntryBlockArgs=*/false);
}

Expand Down Expand Up @@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
/*proc_bind_val=*/nullptr);
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
/*privatizers=*/nullptr);
state.addAttributes(attributes);
}

template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
auto privateVars = op.getPrivateVars();
auto privatizers = op.getPrivatizersAttr();

if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
return success();

auto numPrivateVars = privateVars.size();
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();

if (numPrivateVars != numPrivatizers)
return op.emitError() << "inconsistent number of private variables and "
"privatizer op symbols, private vars: "
<< numPrivateVars
<< " vs. privatizer op symbols: " << numPrivatizers;

for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
Type varType = std::get<0>(privateVarInfo).getType();
SymbolRefAttr privatizerSym =
std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
PrivateClauseOp privatizerOp =
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
privatizerSym);

if (privatizerOp == nullptr)
return op.emitError() << "failed to lookup privatizer op with symbol: '"
<< privatizerSym << "'";

Type privatizerType = privatizerOp.getType();

if (varType != privatizerType)
return op.emitError()
<< "type mismatch between a "
<< (privatizerOp.getDataSharingType() ==
DataSharingClauseType::Private
? "private"
: "firstprivate")
<< " variable and its privatizer op, var type: " << varType
<< " vs. privatizer op type: " << privatizerType;
}

return success();
}

LogicalResult ParallelOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");

if (failed(verifyPrivateVarList(*this)))
return failure();

return verifyReductionVarList(*this, getReductions(), getReductionVars());
}

Expand Down Expand Up @@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,

// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
bool hasReduction = succeeded(
parseReductionClause(parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
succeeded(parseClauseWithRegionArgs(
parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));

if (parser.parseKeyword("for"))
return failure();
Expand Down Expand Up @@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
if (reductionSymbols) {
auto reductionArgs =
region.front().getArguments().drop_front(loopVarTypes.size());
printReductionClause(p, op, reductionArgs, reductionOperands,
reductionTypes, reductionSymbols);
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
reductionOperands, reductionTypes,
reductionSymbols);
}

p << " for ";
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
^bb0(%arg0: f32):
omp.yield(%arg0 : f32)
}

// -----

func.func @private_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}

// -----

func.func @firstprivate_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
} copy {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}

// -----

func.func @undefined_privatizer(%arg0: index) {
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
return
}
Loading