@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
430
430
// Parser, printer and verifier for ReductionVarList
431
431
// ===----------------------------------------------------------------------===//
432
432
433
- ParseResult
434
- parseReductionClause (OpAsmParser &parser, Region ®ion,
435
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
437
- SmallVectorImpl<OpAsmParser::Argument> &privates) {
438
- if (failed (parser.parseOptionalKeyword (" reduction" )))
439
- return failure ();
440
-
433
+ ParseResult parseClauseWithRegionArgs (
434
+ OpAsmParser &parser, Region ®ion,
435
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436
+ SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437
+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
441
438
SmallVector<SymbolRefAttr> reductionVec;
439
+ unsigned regionArgOffset = regionPrivateArgs.size ();
442
440
443
441
if (failed (
444
442
parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren, [&]() {
445
443
if (parser.parseAttribute (reductionVec.emplace_back ()) ||
446
444
parser.parseOperand (operands.emplace_back ()) ||
447
445
parser.parseArrow () ||
448
- parser.parseArgument (privates .emplace_back ()) ||
446
+ parser.parseArgument (regionPrivateArgs .emplace_back ()) ||
449
447
parser.parseColonType (types.emplace_back ()))
450
448
return failure ();
451
449
return success ();
452
450
})))
453
451
return failure ();
454
452
455
- for (auto [prv, type] : llvm::zip_equal (privates, types)) {
453
+ auto *argsBegin = regionPrivateArgs.begin ();
454
+ MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
455
+ argsBegin + regionArgOffset + types.size ());
456
+ for (auto [prv, type] : llvm::zip_equal (argsSubrange, types)) {
456
457
prv.type = type;
457
458
}
458
459
SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
459
- reductionSymbols = ArrayAttr::get (parser.getContext (), reductions);
460
+ symbols = ArrayAttr::get (parser.getContext (), reductions);
460
461
return success ();
461
462
}
462
463
463
- static void printReductionClause (OpAsmPrinter &p, Operation *op,
464
- ValueRange reductionArgs, ValueRange operands,
465
- TypeRange types, ArrayAttr reductionSymbols) {
466
- p << " reduction(" ;
464
+ static void printClauseWithRegionArgs (OpAsmPrinter &p, Operation *op,
465
+ ValueRange argsSubrange,
466
+ StringRef clauseName, ValueRange operands,
467
+ TypeRange types, ArrayAttr symbols) {
468
+ p << clauseName << " (" ;
467
469
llvm::interleaveComma (
468
- llvm::zip_equal (reductionSymbols, operands, reductionArgs, types), p,
469
- [&p](auto t) {
470
+ llvm::zip_equal (symbols, operands, argsSubrange, types), p, [&p](auto t) {
470
471
auto [sym, op, arg, type] = t;
471
472
p << sym << " " << op << " -> " << arg << " : " << type;
472
473
});
473
474
p << " ) " ;
474
475
}
475
476
476
- static ParseResult
477
- parseParallelRegion (OpAsmParser &parser, Region ®ion,
478
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
477
+ static ParseResult parseParallelRegion (
478
+ OpAsmParser &parser, Region ®ion,
479
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
480
+ SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
481
+ llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
482
+ llvm::SmallVectorImpl<Type> &privateVarsTypes,
483
+ ArrayAttr &privatizerSymbols) {
484
+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
480
485
481
- llvm::SmallVector<OpAsmParser::Argument> privates;
482
- if (succeeded (parseReductionClause (parser, region, operands, types,
483
- reductionSymbols, privates)))
484
- return parser.parseRegion (region, privates);
486
+ if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
487
+ if (failed (parseClauseWithRegionArgs (parser, region, reductionVarOperands,
488
+ reductionVarTypes, reductionSymbols,
489
+ regionPrivateArgs)))
490
+ return failure ();
491
+ }
485
492
486
- return parser.parseRegion (region);
493
+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
494
+ if (failed (parseClauseWithRegionArgs (parser, region, privateVarOperands,
495
+ privateVarsTypes, privatizerSymbols,
496
+ regionPrivateArgs)))
497
+ return failure ();
498
+ }
499
+
500
+ return parser.parseRegion (region, regionPrivateArgs);
487
501
}
488
502
489
503
static void printParallelRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
490
- ValueRange operands, TypeRange types,
491
- ArrayAttr reductionSymbols) {
492
- if (reductionSymbols)
493
- printReductionClause (p, op, region.front ().getArguments (), operands, types,
494
- reductionSymbols);
504
+ ValueRange reductionVarOperands,
505
+ TypeRange reductionVarTypes,
506
+ ArrayAttr reductionSymbols,
507
+ ValueRange privateVarOperands,
508
+ TypeRange privateVarTypes,
509
+ ArrayAttr privatizerSymbols) {
510
+ if (reductionSymbols) {
511
+ auto *argsBegin = region.front ().getArguments ().begin ();
512
+ MutableArrayRef argsSubrange (argsBegin,
513
+ argsBegin + reductionVarTypes.size ());
514
+ printClauseWithRegionArgs (p, op, argsSubrange, " reduction" ,
515
+ reductionVarOperands, reductionVarTypes,
516
+ reductionSymbols);
517
+ }
518
+
519
+ if (privatizerSymbols) {
520
+ auto *argsBegin = region.front ().getArguments ().begin ();
521
+ MutableArrayRef argsSubrange (argsBegin + reductionVarOperands.size (),
522
+ argsBegin + reductionVarOperands.size () +
523
+ privateVarTypes.size ());
524
+ printClauseWithRegionArgs (p, op, argsSubrange, " private" ,
525
+ privateVarOperands, privateVarTypes,
526
+ privatizerSymbols);
527
+ }
528
+
495
529
p.printRegion (region, /* printEntryBlockArgs=*/ false );
496
530
}
497
531
@@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
1174
1208
builder, state, /* if_expr_var=*/ nullptr , /* num_threads_var=*/ nullptr ,
1175
1209
/* allocate_vars=*/ ValueRange (), /* allocators_vars=*/ ValueRange (),
1176
1210
/* reduction_vars=*/ ValueRange (), /* reductions=*/ nullptr ,
1177
- /* proc_bind_val=*/ nullptr );
1211
+ /* proc_bind_val=*/ nullptr , /* private_vars=*/ ValueRange (),
1212
+ /* privatizers=*/ nullptr );
1178
1213
state.addAttributes (attributes);
1179
1214
}
1180
1215
1216
+ template <typename OpType>
1217
+ static LogicalResult verifyPrivateVarList (OpType &op) {
1218
+ auto privateVars = op.getPrivateVars ();
1219
+ auto privatizers = op.getPrivatizersAttr ();
1220
+
1221
+ if (privateVars.empty () && (privatizers == nullptr || privatizers.empty ()))
1222
+ return success ();
1223
+
1224
+ auto numPrivateVars = privateVars.size ();
1225
+ auto numPrivatizers = (privatizers == nullptr ) ? 0 : privatizers.size ();
1226
+
1227
+ if (numPrivateVars != numPrivatizers)
1228
+ return op.emitError () << " inconsistent number of private variables and "
1229
+ " privatizer op symbols, private vars: "
1230
+ << numPrivateVars
1231
+ << " vs. privatizer op symbols: " << numPrivatizers;
1232
+
1233
+ for (auto privateVarInfo : llvm::zip_equal (privateVars, privatizers)) {
1234
+ Type varType = std::get<0 >(privateVarInfo).getType ();
1235
+ SymbolRefAttr privatizerSym =
1236
+ std::get<1 >(privateVarInfo).template cast <SymbolRefAttr>();
1237
+ PrivateClauseOp privatizerOp =
1238
+ SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1239
+ privatizerSym);
1240
+
1241
+ if (privatizerOp == nullptr )
1242
+ return op.emitError () << " failed to lookup privatizer op with symbol: '"
1243
+ << privatizerSym << " '" ;
1244
+
1245
+ Type privatizerType = privatizerOp.getType ();
1246
+
1247
+ if (varType != privatizerType)
1248
+ return op.emitError ()
1249
+ << " type mismatch between a "
1250
+ << (privatizerOp.getDataSharingType () ==
1251
+ DataSharingClauseType::Private
1252
+ ? " private"
1253
+ : " firstprivate" )
1254
+ << " variable and its privatizer op, var type: " << varType
1255
+ << " vs. privatizer op type: " << privatizerType;
1256
+ }
1257
+
1258
+ return success ();
1259
+ }
1260
+
1181
1261
LogicalResult ParallelOp::verify () {
1182
1262
if (getAllocateVars ().size () != getAllocatorsVars ().size ())
1183
1263
return emitError (
1184
1264
" expected equal sizes for allocate and allocator variables" );
1265
+
1266
+ if (failed (verifyPrivateVarList (*this )))
1267
+ return failure ();
1268
+
1185
1269
return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1186
1270
}
1187
1271
@@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region ®ion,
1279
1363
1280
1364
// Parse an optional reduction clause
1281
1365
llvm::SmallVector<OpAsmParser::Argument> privates;
1282
- bool hasReduction = succeeded (
1283
- parseReductionClause (parser, region, reductionOperands, reductionTypes,
1284
- reductionSymbols, privates));
1366
+ bool hasReduction = succeeded (parser.parseOptionalKeyword (" reduction" )) &&
1367
+ succeeded (parseClauseWithRegionArgs (
1368
+ parser, region, reductionOperands, reductionTypes,
1369
+ reductionSymbols, privates));
1285
1370
1286
1371
if (parser.parseKeyword (" for" ))
1287
1372
return failure ();
@@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region ®ion,
1328
1413
if (reductionSymbols) {
1329
1414
auto reductionArgs =
1330
1415
region.front ().getArguments ().drop_front (loopVarTypes.size ());
1331
- printReductionClause (p, op, reductionArgs, reductionOperands,
1332
- reductionTypes, reductionSymbols);
1416
+ printClauseWithRegionArgs (p, op, reductionArgs, " reduction" ,
1417
+ reductionOperands, reductionTypes,
1418
+ reductionSymbols);
1333
1419
}
1334
1420
1335
1421
p << " for " ;
0 commit comments