12
12
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
13
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
14
14
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
+ #include " torch-mlir/Dialect/Torch/Utils/Utils.h"
15
16
16
17
using namespace mlir ;
17
18
using namespace mlir ::torch;
@@ -71,6 +72,53 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
71
72
return success ();
72
73
}
73
74
};
75
+
76
+ class RecomposeSelectFill_ : public OpRewritePattern <AtenFill_TensorOp> {
77
+ public:
78
+ using OpRewritePattern::OpRewritePattern;
79
+ LogicalResult matchAndRewrite (AtenFill_TensorOp op,
80
+ PatternRewriter &rewriter) const override {
81
+ if (!op.getSelf ().getDefiningOp () ||
82
+ !isa<AtenSelectIntOp>(op.getSelf ().getDefiningOp ()))
83
+ return failure ();
84
+ auto selectOp = cast<AtenSelectIntOp>(op.getSelf ().getDefiningOp ());
85
+
86
+ // Get indices
87
+ int64_t dim;
88
+ if (!matchPattern (selectOp.getDim (), m_TorchConstantInt (&dim)))
89
+ return failure ();
90
+
91
+ Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
92
+ Value falseVal = rewriter.create <ConstantBoolOp>(op.getLoc (), false );
93
+
94
+ // Create IndexPut_Op
95
+ // Convert indexNum to indexTensor for the selectOp
96
+ BaseTensorType selectOutTy = selectOp.getType ().template cast <BaseTensorType>();
97
+ SmallVector<int64_t > empty;
98
+ auto dtype =
99
+ getTypeForTorchType (selectOp.getContext (), selectOp.getIndex ().getType ());
100
+ Type emptyTensorType = selectOutTy.getWithSizesAndDtype (llvm::ArrayRef (empty), dtype);
101
+ Value indexTensor = rewriter.create <PrimNumToTensorScalarOp>(
102
+ selectOp.getLoc (), emptyTensorType, selectOp.getIndex ());
103
+
104
+ // Create indicesVector for IndexPut_Op by TorchNone and indexTensor
105
+ BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
106
+ SmallVector<Value> indicesVector (dim-1 , noneVal);
107
+ indicesVector.push_back (indexTensor);
108
+
109
+ Value indices = rewriter.create <PrimListConstructOp>(
110
+ op.getLoc (),
111
+ Torch::ListType::get (op->getContext (),
112
+ Torch::OptionalType::get (tensorType)),
113
+ indicesVector);
114
+
115
+ rewriter.replaceOpWithNewOp <Aten_IndexPutImpl_Op>(
116
+ op, op->getResultTypes (), selectOp.getSelf (), indices, op.getValue (),
117
+ /* accumulate=*/ falseVal, /* unsafe=*/ falseVal);
118
+
119
+ return success ();
120
+ }
121
+ };
74
122
} // namespace
75
123
76
124
namespace {
@@ -83,6 +131,7 @@ class RecomposeComplexOpsPass
83
131
84
132
// pattern.add calls go here
85
133
patterns.add <RecomposeSliceCopy_>(context);
134
+ patterns.add <RecomposeSelectFill_>(context);
86
135
87
136
GreedyRewriteConfig config;
88
137
config.useTopDownTraversal = true ;
0 commit comments