@@ -1103,29 +1103,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1103
1103
static void getGenericEffectsImpl (
1104
1104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1105
1105
&effects,
1106
- ValueRange results, const ValueRange inputOperands,
1107
- ValueRange outputOperands) {
1108
- for (auto operand : inputOperands) {
1106
+ LinalgOp linalgOp) {
1107
+ SmallVector<Value> inputOperands = linalgOp. getDpsInputs ();
1108
+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
1109
1109
if (!llvm::isa<MemRefType>(operand.getType ()))
1110
1110
continue ;
1111
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1112
- SideEffects::DefaultResource::get ());
1111
+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1112
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1113
+ /* effectOnFullRegion=*/ true ,
1114
+ SideEffects::DefaultResource::get ());
1115
+ }
1113
1116
}
1114
- for (auto operand : outputOperands) {
1117
+ unsigned inputOperandSize = inputOperands.size ();
1118
+
1119
+ for (auto [index, operand] : llvm::enumerate (linalgOp.getDpsInits ())) {
1115
1120
if (!llvm::isa<MemRefType>(operand.getType ()))
1116
1121
continue ;
1117
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1118
- SideEffects::DefaultResource::get ());
1119
- effects.emplace_back (MemoryEffects::Write::get (), operand,
1122
+ if (linalgOp.payloadUsesValueFromOperand (
1123
+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1124
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
1125
+ /* effectOnFullRegion=*/ true ,
1126
+ SideEffects::DefaultResource::get ());
1127
+ }
1128
+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
1129
+ /* effectOnFullRegion=*/ true ,
1120
1130
SideEffects::DefaultResource::get ());
1121
1131
}
1122
1132
}
1123
1133
1124
1134
void GenericOp::getEffects (
1125
1135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1126
1136
&effects) {
1127
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1128
- getDpsInits ());
1137
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1129
1138
}
1130
1139
1131
1140
LogicalResult GenericOp::verify () { return success (); }
@@ -1473,8 +1482,7 @@ ArrayAttr MapOp::getIndexingMaps() {
1473
1482
void MapOp::getEffects (
1474
1483
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1475
1484
&effects) {
1476
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1477
- getDpsInits ());
1485
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1478
1486
}
1479
1487
1480
1488
// ===----------------------------------------------------------------------===//
@@ -1542,8 +1550,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
1542
1550
void ReduceOp::getEffects (
1543
1551
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1544
1552
&effects) {
1545
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1546
- getDpsInits ());
1553
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1547
1554
}
1548
1555
1549
1556
static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
@@ -1827,8 +1834,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
1827
1834
void TransposeOp::getEffects (
1828
1835
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1829
1836
&effects) {
1830
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1831
- getDpsInits ());
1837
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1832
1838
}
1833
1839
1834
1840
LogicalResult TransposeOp::fold (FoldAdaptor adaptor,
@@ -1965,8 +1971,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
1965
1971
void BroadcastOp::getEffects (
1966
1972
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1967
1973
&effects) {
1968
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1969
- getDpsInits ());
1974
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1970
1975
}
1971
1976
1972
1977
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2494,8 +2499,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
2494
2499
void SoftmaxOp::getEffects (
2495
2500
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2496
2501
&effects) {
2497
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2498
- getDpsInits ());
2502
+ for (Value operand : getDpsInputs ()) {
2503
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2504
+ continue ;
2505
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2506
+ /* effectOnFullRegion=*/ true ,
2507
+ SideEffects::DefaultResource::get ());
2508
+ }
2509
+ for (Value operand : getDpsInits ()) {
2510
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2511
+ continue ;
2512
+ effects.emplace_back (MemoryEffects::Read::get (), operand, /* stage=*/ 0 ,
2513
+ /* effectOnFullRegion=*/ true ,
2514
+ SideEffects::DefaultResource::get ());
2515
+ effects.emplace_back (MemoryEffects::Write::get (), operand, /* stage=*/ 0 ,
2516
+ /* effectOnFullRegion=*/ true ,
2517
+ SideEffects::DefaultResource::get ());
2518
+ }
2499
2519
}
2500
2520
2501
2521
// Helper functions for softmax decomposition.
0 commit comments