12
12
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
13
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
14
14
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
+ #include " torch-mlir/Dialect/Torch/Utils/Utils.h"
15
16
16
17
using namespace mlir ;
17
18
using namespace mlir ::torch;
@@ -71,6 +72,56 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
71
72
return success ();
72
73
}
73
74
};
75
+
76
+ class RecomposeSelectFill_ : public OpRewritePattern <AtenFill_TensorOp> {
77
+ public:
78
+ using OpRewritePattern::OpRewritePattern;
79
+ LogicalResult matchAndRewrite (AtenFill_TensorOp op,
80
+ PatternRewriter &rewriter) const override {
81
+ if (!op.getSelf ().getDefiningOp () ||
82
+ !isa<AtenSelectIntOp>(op.getSelf ().getDefiningOp ()))
83
+ return failure ();
84
+ auto selectOp = cast<AtenSelectIntOp>(op.getSelf ().getDefiningOp ());
85
+
86
+ // Get indices
87
+ int64_t dim;
88
+ if (!matchPattern (selectOp.getDim (), m_TorchConstantInt (&dim)))
89
+ return failure ();
90
+
91
+ Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
92
+ Value falseVal = rewriter.create <ConstantBoolOp>(op.getLoc (), false );
93
+
94
+ // Create IndexPut_Op
95
+ // Convert indexNum to indexTensor for the selectOp
96
+ BaseTensorType selectOutTy = selectOp.getType ().template cast <BaseTensorType>();
97
+ SmallVector<int64_t > empty;
98
+ auto dtype =
99
+ getTypeForTorchType (selectOp.getContext (), selectOp.getIndex ().getType ());
100
+ Type emptyTensorType = selectOutTy.getWithSizesAndDtype (llvm::ArrayRef (empty), dtype);
101
+ Value indexTensor = rewriter.create <PrimNumToTensorScalarOp>(
102
+ selectOp.getLoc (), emptyTensorType, selectOp.getIndex ());
103
+
104
+ // Create indicesVector for IndexPut_Op by TorchNone and indexTensor
105
+ BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
106
+ SmallVector<Value> indicesVector;
107
+ for (auto i = 0 ; i < dim-1 ; i++) {
108
+ indicesVector.push_back (noneVal);
109
+ }
110
+ indicesVector.push_back (indexTensor);
111
+
112
+ Value indices = rewriter.create <PrimListConstructOp>(
113
+ op.getLoc (),
114
+ Torch::ListType::get (op->getContext (),
115
+ Torch::OptionalType::get (tensorType)),
116
+ indicesVector);
117
+
118
+ rewriter.replaceOpWithNewOp <Aten_IndexPutImpl_Op>(
119
+ op, op->getResultTypes (), selectOp.getSelf (), indices, op.getValue (),
120
+ /* accumulate=*/ falseVal, /* unsafe=*/ falseVal);
121
+
122
+ return success ();
123
+ }
124
+ };
74
125
} // namespace
75
126
76
127
namespace {
@@ -83,6 +134,7 @@ class RecomposeComplexOpsPass
83
134
84
135
// pattern.add calls go here
85
136
patterns.add <RecomposeSliceCopy_>(context);
137
+ patterns.add <RecomposeSelectFill_>(context);
86
138
87
139
GreedyRewriteConfig config;
88
140
config.useTopDownTraversal = true ;
0 commit comments