12
12
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
13
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
14
14
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
+ #include " torch-mlir/Dialect/Torch/Utils/Utils.h"
15
16
16
17
using namespace mlir ;
17
18
using namespace mlir ::torch;
@@ -71,6 +72,55 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
71
72
return success ();
72
73
}
73
74
};
75
+
76
+ class RecomposeSelectFill_ : public OpRewritePattern <AtenFill_TensorOp> {
77
+ public:
78
+ using OpRewritePattern::OpRewritePattern;
79
+ LogicalResult matchAndRewrite (AtenFill_TensorOp op,
80
+ PatternRewriter &rewriter) const override {
81
+ if (!op.getSelf ().getDefiningOp () ||
82
+ !isa<AtenSelectIntOp>(op.getSelf ().getDefiningOp ()))
83
+ return failure ();
84
+ auto selectOp = cast<AtenSelectIntOp>(op.getSelf ().getDefiningOp ());
85
+
86
+ // Get indices
87
+ int64_t dim;
88
+ if (!matchPattern (selectOp.getDim (), m_TorchConstantInt (&dim)))
89
+ return failure ();
90
+
91
+ Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
92
+ Value falseVal = rewriter.create <ConstantBoolOp>(op.getLoc (), false );
93
+
94
+ // Create IndexPut_Op
95
+ // Convert indexNum to indexTensor for the selectOp
96
+ BaseTensorType selectOutTy =
97
+ selectOp.getType ().template cast <BaseTensorType>();
98
+ SmallVector<int64_t > empty;
99
+ auto dtype = getTypeForTorchType (selectOp.getContext (),
100
+ selectOp.getIndex ().getType ());
101
+ Type emptyTensorType =
102
+ selectOutTy.getWithSizesAndDtype (llvm::ArrayRef (empty), dtype);
103
+ Value indexTensor = rewriter.create <PrimNumToTensorScalarOp>(
104
+ selectOp.getLoc (), emptyTensorType, selectOp.getIndex ());
105
+
106
+ // Create indicesVector for IndexPut_Op by TorchNone and indexTensor
107
+ BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
108
+ SmallVector<Value> indicesVector (dim - 1 , noneVal);
109
+ indicesVector.push_back (indexTensor);
110
+
111
+ Value indices = rewriter.create <PrimListConstructOp>(
112
+ op.getLoc (),
113
+ Torch::ListType::get (op->getContext (),
114
+ Torch::OptionalType::get (tensorType)),
115
+ indicesVector);
116
+
117
+ rewriter.replaceOpWithNewOp <Aten_IndexPutImpl_Op>(
118
+ op, op->getResultTypes (), selectOp.getSelf (), indices, op.getValue (),
119
+ /* accumulate=*/ falseVal, /* unsafe=*/ falseVal);
120
+
121
+ return success ();
122
+ }
123
+ };
74
124
} // namespace
75
125
76
126
namespace {
@@ -83,6 +133,7 @@ class RecomposeComplexOpsPass
83
133
84
134
// pattern.add calls go here
85
135
patterns.add <RecomposeSliceCopy_>(context);
136
+ patterns.add <RecomposeSelectFill_>(context);
86
137
87
138
GreedyRewriteConfig config;
88
139
config.useTopDownTraversal = true ;
0 commit comments