@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
430
430
// Parser, printer and verifier for ReductionVarList
431
431
// ===----------------------------------------------------------------------===//
432
432
433
- ParseResult
434
- parseReductionClause (OpAsmParser &parser, Region ®ion,
435
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
437
- SmallVectorImpl<OpAsmParser::Argument> &privates) {
438
- if (failed (parser.parseOptionalKeyword (" reduction" )))
439
- return failure ();
440
-
433
+ ParseResult parseClauseWithRegionArgs (
434
+ OpAsmParser &parser, Region ®ion,
435
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
436
+ SmallVectorImpl<Type> &types, ArrayAttr &symbols,
437
+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
441
438
SmallVector<SymbolRefAttr> reductionVec;
439
+ unsigned regionArgOffset = regionPrivateArgs.size ();
442
440
443
441
if (failed (
444
442
parser.parseCommaSeparatedList (OpAsmParser::Delimiter::Paren, [&]() {
445
443
if (parser.parseAttribute (reductionVec.emplace_back ()) ||
446
444
parser.parseOperand (operands.emplace_back ()) ||
447
445
parser.parseArrow () ||
448
- parser.parseArgument (privates .emplace_back ()) ||
446
+ parser.parseArgument (regionPrivateArgs .emplace_back ()) ||
449
447
parser.parseColonType (types.emplace_back ()))
450
448
return failure ();
451
449
return success ();
452
450
})))
453
451
return failure ();
454
452
455
- for (auto [prv, type] : llvm::zip_equal (privates, types)) {
453
+ auto *argsBegin = regionPrivateArgs.begin ();
454
+ MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
455
+ argsBegin + regionArgOffset + types.size ());
456
+ for (auto [prv, type] : llvm::zip_equal (argsSubrange, types)) {
456
457
prv.type = type;
457
458
}
458
459
SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
459
- reductionSymbols = ArrayAttr::get (parser.getContext (), reductions);
460
+ symbols = ArrayAttr::get (parser.getContext (), reductions);
460
461
return success ();
461
462
}
462
463
463
- static void printReductionClause (OpAsmPrinter &p, Operation *op,
464
- ValueRange reductionArgs, ValueRange operands,
465
- TypeRange types, ArrayAttr reductionSymbols) {
466
- p << " reduction(" ;
464
+ static void printClauseWithRegionArgs (OpAsmPrinter &p, Operation *op,
465
+ ValueRange argsSubrange,
466
+ StringRef clauseName, ValueRange operands,
467
+ TypeRange types, ArrayAttr symbols) {
468
+ p << clauseName << " (" ;
467
469
llvm::interleaveComma (
468
- llvm::zip_equal (reductionSymbols, operands, reductionArgs, types), p,
469
- [&p](auto t) {
470
+ llvm::zip_equal (symbols, operands, argsSubrange, types), p, [&p](auto t) {
470
471
auto [sym, op, arg, type] = t;
471
472
p << sym << " " << op << " -> " << arg << " : " << type;
472
473
});
473
474
p << " ) " ;
474
475
}
475
476
476
- static ParseResult
477
- parseParallelRegion (OpAsmParser &parser, Region ®ion,
478
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
479
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
477
+ static ParseResult parseParallelRegion (
478
+ OpAsmParser &parser, Region ®ion,
479
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
480
+ SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
481
+ llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
482
+ llvm::SmallVectorImpl<Type> &privateVarsTypes,
483
+ ArrayAttr &privatizerSymbols) {
484
+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
480
485
481
- llvm::SmallVector<OpAsmParser::Argument> privates;
482
- if (succeeded (parseReductionClause (parser, region, operands, types,
483
- reductionSymbols, privates)))
484
- return parser.parseRegion (region, privates);
486
+ if (succeeded (parser.parseOptionalKeyword (" reduction" ))) {
487
+ if (failed (parseClauseWithRegionArgs (parser, region, reductionVarOperands,
488
+ reductionVarTypes, reductionSymbols,
489
+ regionPrivateArgs)))
490
+ return failure ();
491
+ }
485
492
486
- return parser.parseRegion (region);
493
+ if (succeeded (parser.parseOptionalKeyword (" private" ))) {
494
+ if (failed (parseClauseWithRegionArgs (parser, region, privateVarOperands,
495
+ privateVarsTypes, privatizerSymbols,
496
+ regionPrivateArgs)))
497
+ return failure ();
498
+ }
499
+
500
+ return parser.parseRegion (region, regionPrivateArgs);
487
501
}
488
502
489
503
static void printParallelRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
490
- ValueRange operands, TypeRange types,
491
- ArrayAttr reductionSymbols) {
492
- if (reductionSymbols)
493
- printReductionClause (p, op, region.front ().getArguments (), operands, types,
494
- reductionSymbols);
504
+ ValueRange reductionVarOperands,
505
+ TypeRange reductionVarTypes,
506
+ ArrayAttr reductionSymbols,
507
+ ValueRange privateVarOperands,
508
+ TypeRange privateVarTypes,
509
+ ArrayAttr privatizerSymbols) {
510
+ if (reductionSymbols) {
511
+ auto *argsBegin = region.front ().getArguments ().begin ();
512
+ MutableArrayRef argsSubrange (argsBegin,
513
+ argsBegin + reductionVarTypes.size ());
514
+ printClauseWithRegionArgs (p, op, argsSubrange, " reduction" ,
515
+ reductionVarOperands, reductionVarTypes,
516
+ reductionSymbols);
517
+ }
518
+
519
+ if (privatizerSymbols) {
520
+ auto *argsBegin = region.front ().getArguments ().begin ();
521
+ MutableArrayRef argsSubrange (argsBegin + reductionVarOperands.size (),
522
+ argsBegin + reductionVarOperands.size () +
523
+ privateVarTypes.size ());
524
+ printClauseWithRegionArgs (p, op, argsSubrange, " private" ,
525
+ privateVarOperands, privateVarTypes,
526
+ privatizerSymbols);
527
+ }
528
+
495
529
p.printRegion (region, /* printEntryBlockArgs=*/ false );
496
530
}
497
531
@@ -1008,9 +1042,8 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
1008
1042
}
1009
1043
1010
1044
if (always || close || implicit) {
1011
- return emitError (
1012
- op->getLoc (),
1013
- " present, mapper and iterator map type modifiers are permitted" );
1045
+ return emitError (op->getLoc (), " present, mapper and iterator map "
1046
+ " type modifiers are permitted" );
1014
1047
}
1015
1048
1016
1049
to ? updateToVars.insert (updateVar) : updateFromVars.insert (updateVar);
@@ -1070,14 +1103,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
1070
1103
builder, state, /* if_expr_var=*/ nullptr , /* num_threads_var=*/ nullptr ,
1071
1104
/* allocate_vars=*/ ValueRange (), /* allocators_vars=*/ ValueRange (),
1072
1105
/* reduction_vars=*/ ValueRange (), /* reductions=*/ nullptr ,
1073
- /* proc_bind_val=*/ nullptr );
1106
+ /* proc_bind_val=*/ nullptr , /* private_vars=*/ ValueRange (),
1107
+ /* privatizers=*/ nullptr );
1074
1108
state.addAttributes (attributes);
1075
1109
}
1076
1110
1111
+ static LogicalResult verifyPrivateVarList (ParallelOp &op) {
1112
+ auto privateVars = op.getPrivateVars ();
1113
+ auto privatizers = op.getPrivatizersAttr ();
1114
+
1115
+ if (privateVars.empty () && (privatizers == nullptr || privatizers.empty ()))
1116
+ return success ();
1117
+
1118
+ auto numPrivateVars = privateVars.size ();
1119
+ auto numPrivatizers = (privatizers == nullptr ) ? 0 : privatizers.size ();
1120
+
1121
+ if (numPrivateVars != numPrivatizers)
1122
+ return op.emitError () << " inconsistent number of private variables and "
1123
+ " privatizer op symbols, private vars: "
1124
+ << numPrivateVars
1125
+ << " vs. privatizer op symbols: " << numPrivatizers;
1126
+
1127
+ for (auto privateVarInfo : llvm::zip (privateVars, privatizers)) {
1128
+ Type varType = std::get<0 >(privateVarInfo).getType ();
1129
+ SymbolRefAttr privatizerSym =
1130
+ std::get<1 >(privateVarInfo).cast <SymbolRefAttr>();
1131
+ PrivateClauseOp privatizerOp =
1132
+ SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1133
+ privatizerSym);
1134
+
1135
+ if (privatizerOp == nullptr )
1136
+ return op.emitError () << " failed to lookup privatizer op with symbol: '"
1137
+ << privatizerSym << " '" ;
1138
+
1139
+ Type privatizerType = privatizerOp.getType ();
1140
+
1141
+ if (varType != privatizerType)
1142
+ return op.emitError ()
1143
+ << " type mismatch between a "
1144
+ << (privatizerOp.getDataSharingType () ==
1145
+ DataSharingClauseType::Private
1146
+ ? " private"
1147
+ : " firstprivate" )
1148
+ << " variable and its privatizer op, var type: " << varType
1149
+ << " vs. privatizer op type: " << privatizerType;
1150
+ }
1151
+
1152
+ return success ();
1153
+ }
1154
+
1077
1155
LogicalResult ParallelOp::verify () {
1078
1156
if (getAllocateVars ().size () != getAllocatorsVars ().size ())
1079
1157
return emitError (
1080
1158
" expected equal sizes for allocate and allocator variables" );
1159
+
1160
+ if (failed (verifyPrivateVarList (*this )))
1161
+ return failure ();
1162
+
1081
1163
return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1082
1164
}
1083
1165
@@ -1111,8 +1193,8 @@ LogicalResult TeamsOp::verify() {
1111
1193
return emitError (" expected num_teams upper bound to be defined if the "
1112
1194
" lower bound is defined" );
1113
1195
if (numTeamsLowerBound.getType () != numTeamsUpperBound.getType ())
1114
- return emitError (
1115
- " expected num_teams upper bound and lower bound to be the same type" );
1196
+ return emitError (" expected num_teams upper bound and lower bound to be "
1197
+ " the same type" );
1116
1198
}
1117
1199
1118
1200
// Check for allocate clause restrictions
@@ -1174,9 +1256,10 @@ parseWsLoop(OpAsmParser &parser, Region ®ion,
1174
1256
1175
1257
// Parse an optional reduction clause
1176
1258
llvm::SmallVector<OpAsmParser::Argument> privates;
1177
- bool hasReduction = succeeded (
1178
- parseReductionClause (parser, region, reductionOperands, reductionTypes,
1179
- reductionSymbols, privates));
1259
+ bool hasReduction = succeeded (parser.parseOptionalKeyword (" reduction" )) &&
1260
+ succeeded (parseClauseWithRegionArgs (
1261
+ parser, region, reductionOperands, reductionTypes,
1262
+ reductionSymbols, privates));
1180
1263
1181
1264
if (parser.parseKeyword (" for" ))
1182
1265
return failure ();
@@ -1223,8 +1306,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region ®ion,
1223
1306
if (reductionSymbols) {
1224
1307
auto reductionArgs =
1225
1308
region.front ().getArguments ().drop_front (loopVarTypes.size ());
1226
- printReductionClause (p, op, reductionArgs, reductionOperands,
1227
- reductionTypes, reductionSymbols);
1309
+ printClauseWithRegionArgs (p, op, reductionArgs, " reduction" ,
1310
+ reductionOperands, reductionTypes,
1311
+ reductionSymbols);
1228
1312
}
1229
1313
1230
1314
p << " for " ;
@@ -1464,9 +1548,9 @@ LogicalResult TaskLoopOp::verify() {
1464
1548
}
1465
1549
1466
1550
if (getGrainSize () && getNumTasks ()) {
1467
- return emitError (
1468
- " the grainsize clause and num_tasks clause are mutually exclusive and "
1469
- " may not appear on the same taskloop directive" );
1551
+ return emitError (" the grainsize clause and num_tasks clause are mutually "
1552
+ " exclusive and "
1553
+ " may not appear on the same taskloop directive" );
1470
1554
}
1471
1555
return success ();
1472
1556
}
@@ -1535,7 +1619,8 @@ LogicalResult OrderedOp::verify() {
1535
1619
}
1536
1620
1537
1621
LogicalResult OrderedRegionOp::verify () {
1538
- // TODO: The code generation for ordered simd directive is not supported yet.
1622
+ // TODO: The code generation for ordered simd directive is not supported
1623
+ // yet.
1539
1624
if (getSimd ())
1540
1625
return failure ();
1541
1626
0 commit comments