@@ -1122,29 +1122,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1122
1122
static void getGenericEffectsImpl (
1123
1123
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1124
1124
&effects,
1125
- ValueRange results, const ValueRange inputOperands,
1126
- ValueRange outputOperands) {
1127
- for (auto operand : inputOperands) {
1125
+ LinalgOp linalgOp) {
1126
+ SmallVector<Value> inputOperands = linalgOp. getDpsInputs ();
1127
+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
1128
1128
if (!llvm::isa<MemRefType>(operand.getType ()))
1129
1129
continue ;
1130
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1131
- SideEffects::DefaultResource::get ());
1130
+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1131
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1132
+ /* effectOnFullRegion=*/ true ,
1133
+ SideEffects::DefaultResource::get ());
1134
+ }
1132
1135
}
1133
- for (auto operand : outputOperands) {
1136
+ unsigned inputOperandSize = inputOperands.size ();
1137
+
1138
+ for (auto [index, operand] : llvm::enumerate (linalgOp.getDpsInits ())) {
1134
1139
if (!llvm::isa<MemRefType>(operand.getType ()))
1135
1140
continue ;
1136
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1137
- SideEffects::DefaultResource::get ());
1138
- effects.emplace_back (MemoryEffects::Write::get (), operand,
1141
+ if (linalgOp.payloadUsesValueFromOperand (
1142
+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1143
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1144
+ /* effectOnFullRegion=*/ true ,
1145
+ SideEffects::DefaultResource::get ());
1146
+ }
1147
+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
1148
+ /* effectOnFullRegion=*/ true ,
1139
1149
SideEffects::DefaultResource::get ());
1140
1150
}
1141
1151
}
1142
1152
1143
1153
void GenericOp::getEffects (
1144
1154
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1145
1155
&effects) {
1146
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1147
- getDpsInits ());
1156
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1148
1157
}
1149
1158
1150
1159
LogicalResult GenericOp::verify () { return success (); }
@@ -1492,8 +1501,7 @@ ArrayAttr MapOp::getIndexingMaps() {
1492
1501
void MapOp::getEffects (
1493
1502
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1494
1503
&effects) {
1495
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1496
- getDpsInits ());
1504
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1497
1505
}
1498
1506
1499
1507
// ===----------------------------------------------------------------------===//
@@ -1561,8 +1569,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
1561
1569
void ReduceOp::getEffects (
1562
1570
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1563
1571
&effects) {
1564
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1565
- getDpsInits ());
1572
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1566
1573
}
1567
1574
1568
1575
static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
@@ -1846,8 +1853,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
1846
1853
void TransposeOp::getEffects (
1847
1854
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1848
1855
&effects) {
1849
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1850
- getDpsInits ());
1856
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1851
1857
}
1852
1858
1853
1859
LogicalResult TransposeOp::fold (FoldAdaptor adaptor,
@@ -1984,8 +1990,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
1984
1990
void BroadcastOp::getEffects (
1985
1991
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1986
1992
&effects) {
1987
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1988
- getDpsInits ());
1993
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1989
1994
}
1990
1995
1991
1996
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2513,8 +2518,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
2513
2518
void SoftmaxOp::getEffects (
2514
2519
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2515
2520
&effects) {
2516
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2517
- getDpsInits ());
2521
+ for (Value operand : getDpsInputs ()) {
2522
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2523
+ continue ;
2524
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2525
+ /* effectOnFullRegion=*/ true ,
2526
+ SideEffects::DefaultResource::get ());
2527
+ }
2528
+ for (Value operand : getDpsInits ()) {
2529
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2530
+ continue ;
2531
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2532
+ /* effectOnFullRegion=*/ true ,
2533
+ SideEffects::DefaultResource::get ());
2534
+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
2535
+ /* effectOnFullRegion=*/ true ,
2536
+ SideEffects::DefaultResource::get ());
2537
+ }
2518
2538
}
2519
2539
2520
2540
// Helper functions for softmax decomposition.
0 commit comments