Skip to content

Commit 4df1d8a

Browse files
author
Chi_Liu
authored
[MLIR] Fold aten select and fill_ pattern (#2000)
1 parent 8dcd0b2 commit 4df1d8a

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1313
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1414
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
15+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1516

1617
using namespace mlir;
1718
using namespace mlir::torch;
@@ -71,6 +72,55 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
7172
return success();
7273
}
7374
};
75+
76+
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
77+
public:
78+
using OpRewritePattern::OpRewritePattern;
79+
LogicalResult matchAndRewrite(AtenFill_TensorOp op,
80+
PatternRewriter &rewriter) const override {
81+
if (!op.getSelf().getDefiningOp() ||
82+
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
83+
return failure();
84+
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
85+
86+
// Get indices
87+
int64_t dim;
88+
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
89+
return failure();
90+
91+
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
92+
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
93+
94+
// Create IndexPut_Op
95+
// Convert indexNum to indexTensor for the selectOp
96+
BaseTensorType selectOutTy =
97+
selectOp.getType().template cast<BaseTensorType>();
98+
SmallVector<int64_t> empty;
99+
auto dtype = getTypeForTorchType(selectOp.getContext(),
100+
selectOp.getIndex().getType());
101+
Type emptyTensorType =
102+
selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
103+
Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
104+
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
105+
106+
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
107+
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
108+
SmallVector<Value> indicesVector(dim - 1, noneVal);
109+
indicesVector.push_back(indexTensor);
110+
111+
Value indices = rewriter.create<PrimListConstructOp>(
112+
op.getLoc(),
113+
Torch::ListType::get(op->getContext(),
114+
Torch::OptionalType::get(tensorType)),
115+
indicesVector);
116+
117+
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
118+
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
119+
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
120+
121+
return success();
122+
}
123+
};
74124
} // namespace
75125

76126
namespace {
@@ -83,6 +133,7 @@ class RecomposeComplexOpsPass
83133

84134
// pattern.add calls go here
85135
patterns.add<RecomposeSliceCopy_>(context);
136+
patterns.add<RecomposeSelectFill_>(context);
86137

87138
GreedyRewriteConfig config;
88139
config.useTopDownTraversal = true;

0 commit comments

Comments
 (0)