Skip to content

Commit a8eedd2

Browse files
AmosLewisAmosLewis
authored andcommitted
[MLIR] Fold aten select and fill_ pattern
1 parent c86f46b commit a8eedd2

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1313
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1414
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
15+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1516

1617
using namespace mlir;
1718
using namespace mlir::torch;
@@ -71,6 +72,56 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
7172
return success();
7273
}
7374
};
75+
76+
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
77+
public:
78+
using OpRewritePattern::OpRewritePattern;
79+
LogicalResult matchAndRewrite(AtenFill_TensorOp op,
80+
PatternRewriter &rewriter) const override {
81+
if (!op.getSelf().getDefiningOp() ||
82+
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
83+
return failure();
84+
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
85+
86+
// Get indices
87+
int64_t dim;
88+
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
89+
return failure();
90+
91+
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
92+
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
93+
94+
// Create IndexPut_Op
95+
// Convert indexNum to indexTensor for the selectOp
96+
BaseTensorType selectOutTy = selectOp.getType().template cast<BaseTensorType>();
97+
SmallVector<int64_t> empty;
98+
auto dtype =
99+
getTypeForTorchType(selectOp.getContext(), selectOp.getIndex().getType());
100+
Type emptyTensorType = selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
101+
Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
102+
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
103+
104+
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
105+
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
106+
SmallVector<Value> indicesVector;
107+
for (auto i = 0; i < dim-1; i++) {
108+
indicesVector.push_back(noneVal);
109+
}
110+
indicesVector.push_back(indexTensor);
111+
112+
Value indices = rewriter.create<PrimListConstructOp>(
113+
op.getLoc(),
114+
Torch::ListType::get(op->getContext(),
115+
Torch::OptionalType::get(tensorType)),
116+
indicesVector);
117+
118+
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
119+
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
120+
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
121+
122+
return success();
123+
}
124+
};
74125
} // namespace
75126

76127
namespace {
@@ -83,6 +134,7 @@ class RecomposeComplexOpsPass
83134

84135
// pattern.add calls go here
85136
patterns.add<RecomposeSliceCopy_>(context);
137+
patterns.add<RecomposeSelectFill_>(context);
86138

87139
GreedyRewriteConfig config;
88140
config.useTopDownTraversal = true;

0 commit comments

Comments
 (0)