12
12
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
13
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
14
14
#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
15
+ #include " torch-mlir/Dialect/Torch/Utils/Utils.h"
15
16
16
17
using namespace mlir ;
17
18
using namespace mlir ::torch;
@@ -71,6 +72,59 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
71
72
return success ();
72
73
}
73
74
};
75
+
76
+ class RecomposeSelectCopy_ : public OpRewritePattern <AtenCopy_Op> {
77
+ public:
78
+ using OpRewritePattern::OpRewritePattern;
79
+ LogicalResult matchAndRewrite (AtenCopy_Op op,
80
+ PatternRewriter &rewriter) const override {
81
+ if (!op.getSelf ().getDefiningOp () ||
82
+ !isa<AtenSelectIntOp>(op.getSelf ().getDefiningOp ()))
83
+ return failure ();
84
+ auto selectOp = cast<AtenSelectIntOp>(op.getSelf ().getDefiningOp ());
85
+
86
+ // Get input tensor type
87
+ Value self = selectOp.getSelf ();
88
+ std::optional<unsigned > maybeRank = getTensorRank (self);
89
+ if (!maybeRank)
90
+ return rewriter.notifyMatchFailure (op, " Unimplemented: unranked tensor" );
91
+ unsigned rank = *maybeRank;
92
+
93
+ // Get indices
94
+ int64_t dim;
95
+ if (!matchPattern (selectOp.getDim (), m_TorchConstantInt (&dim)))
96
+ return failure ();
97
+ int64_t index;
98
+ if (!matchPattern (selectOp.getIndex (), m_TorchConstantInt (&index)))
99
+ return failure ();
100
+
101
+ Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
102
+ Value falseVal = rewriter.create <ConstantBoolOp>(op.getLoc (), false );
103
+
104
+ // Create IndexPut_Op
105
+ BaseTensorType tensorType = op->getResultTypes ()[0 ].cast <BaseTensorType>();
106
+ SmallVector<Value> indicesVector;
107
+ for (auto i = 0 ; i < rank; i++) {
108
+ if (i == dim){
109
+ indicesVector.push_back (selectOp.getIndex ());
110
+ } else {
111
+ indicesVector.push_back (noneVal);
112
+ }
113
+ }
114
+
115
+ Value indices = rewriter.create <PrimListConstructOp>(
116
+ op.getLoc (),
117
+ Torch::ListType::get (op->getContext (),
118
+ Torch::OptionalType::get (tensorType)),
119
+ indicesVector);
120
+
121
+ rewriter.replaceOpWithNewOp <Aten_IndexPutImpl_Op>(
122
+ op, op->getResultTypes (), selectOp.getSelf (), indices, op.getSrc (),
123
+ /* accumulate=*/ falseVal, /* unsafe=*/ falseVal);
124
+
125
+ return success ();
126
+ }
127
+ };
74
128
} // namespace
75
129
76
130
namespace {
@@ -83,6 +137,7 @@ class RecomposeComplexOpsPass
83
137
84
138
// pattern.add calls go here
85
139
patterns.add <RecomposeSliceCopy_>(context);
140
+ patterns.add <RecomposeSelectCopy_>(context);
86
141
87
142
GreedyRewriteConfig config;
88
143
config.useTopDownTraversal = true ;
0 commit comments