Skip to content

Commit 72f715b

Browse files
AmosLewisAmosLewis
authored andcommitted
[MLIR] Fold aten select and copy pattern
1 parent c86f46b commit 72f715b

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1313
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1414
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
15+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1516

1617
using namespace mlir;
1718
using namespace mlir::torch;
@@ -71,6 +72,59 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
7172
return success();
7273
}
7374
};
75+
76+
class RecomposeSelectCopy_ : public OpRewritePattern<AtenCopy_Op> {
77+
public:
78+
using OpRewritePattern::OpRewritePattern;
79+
LogicalResult matchAndRewrite(AtenCopy_Op op,
80+
PatternRewriter &rewriter) const override {
81+
if (!op.getSelf().getDefiningOp() ||
82+
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
83+
return failure();
84+
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
85+
86+
// Get input tensor type
87+
Value self = selectOp.getSelf();
88+
std::optional<unsigned> maybeRank = getTensorRank(self);
89+
if (!maybeRank)
90+
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
91+
unsigned rank = *maybeRank;
92+
93+
// Get indices
94+
int64_t dim;
95+
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
96+
return failure();
97+
int64_t index;
98+
if (!matchPattern(selectOp.getIndex(), m_TorchConstantInt(&index)))
99+
return failure();
100+
101+
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
102+
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
103+
104+
// Create IndexPut_Op
105+
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
106+
SmallVector<Value> indicesVector;
107+
for (auto i = 0; i < rank; i++) {
108+
if(i == dim){
109+
indicesVector.push_back(selectOp.getIndex());
110+
} else {
111+
indicesVector.push_back(noneVal);
112+
}
113+
}
114+
115+
Value indices = rewriter.create<PrimListConstructOp>(
116+
op.getLoc(),
117+
Torch::ListType::get(op->getContext(),
118+
Torch::OptionalType::get(tensorType)),
119+
indicesVector);
120+
121+
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
122+
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getSrc(),
123+
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
124+
125+
return success();
126+
}
127+
};
74128
} // namespace
75129

76130
namespace {
@@ -83,6 +137,7 @@ class RecomposeComplexOpsPass
83137

84138
// pattern.add calls go here
85139
patterns.add<RecomposeSliceCopy_>(context);
140+
patterns.add<RecomposeSelectCopy_>(context);
86141

87142
GreedyRewriteConfig config;
88143
config.useTopDownTraversal = true;

0 commit comments

Comments
 (0)