@@ -27,6 +27,91 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

+// Computes maxpool2d for AtenMaxPool2dOp and AtenMaxPool2dWithIndicesOp.
+template <typename OpTy>
+static LogicalResult computeMaxPool2d(OpTy op,
+                                      ConversionPatternRewriter &rewriter,
+                                      Value self, Value &result) {
+  Location loc = op.getLoc();
+  Value ceilMode = op.ceil_mode();
+  Type elementType = self.getType().cast<RankedTensorType>().getElementType();
+  if (!elementType.isa<mlir::FloatType>())
+    return op.emitError("unimplemented: non-floating point type");
+
+  // Pattern match against the op's original operands, because otherwise we
+  // will get the lowered version of the operands which is harder to pattern
+  // match.
+  SmallVector<int64_t, 2> strideInts;
+  if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
+    return rewriter.notifyMatchFailure(op, "only support constant int strides");
+  SmallVector<int64_t, 2> dilationInts;
+  if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
+    return rewriter.notifyMatchFailure(op,
+                                       "only support constant int dilations");
+  SmallVector<int64_t, 2> paddingInts;
+  if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
+    return rewriter.notifyMatchFailure(op,
+                                       "only support constant int paddings");
+  SmallVector<int64_t, 2> kernelSizeInts;
+  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
+    return rewriter.notifyMatchFailure(op, "only support kernel size ints");
+  bool ceilModeFalse = false;
+  if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilModeFalse)))
+    return rewriter.notifyMatchFailure(op, "only ceil_mode false is supported");
+
+  SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
+  paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
+                            paddingInts.end());
+  Value paddedInput =
+      torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC);
+
+  Value N = getDimOp(rewriter, loc, self, 0);
+  Value C = getDimOp(rewriter, loc, self, 1);
+  Value H = getDimOp(rewriter, loc, self, 2);
+  Value W = getDimOp(rewriter, loc, self, 3);
+
+  SmallVector<Value> paddingIntValues =
+      getAsConstantIntValues(rewriter, loc, paddingInts);
+  SmallVector<Value> dilationIntValues =
+      getAsConstantIntValues(rewriter, loc, dilationInts);
+  SmallVector<Value> kernelSizeIntValues =
+      getAsConstantIntValues(rewriter, loc, kernelSizeInts);
+  SmallVector<Value> strideIntValues =
+      getAsConstantIntValues(rewriter, loc, strideInts);
+
+  Value Hout = torch_to_linalg::getOutputDimForConvOps(
+      rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
+      kernelSizeIntValues[0], strideIntValues[0]);
+  Value Wout = torch_to_linalg::getOutputDimForConvOps(
+      rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
+      kernelSizeIntValues[1], strideIntValues[1]);
+
+  // Initialize output tensor with smallest floating point value
+  Value outTensor = rewriter.create<linalg::InitTensorOp>(
+      loc, ValueRange{N, C, Hout, Wout}, elementType);
+  auto initialAttr = rewriter.getFloatAttr(
+      elementType, APFloat::getSmallest(
+                       elementType.cast<mlir::FloatType>().getFloatSemantics(),
+                       /*Negative=*/true));
+  Value initValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
+  Value outTensorInitialized =
+      rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
+
+  auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
+  auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
+  Value windowTensor = rewriter.create<linalg::InitTensorOp>(
+      loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
+      elementType);
+
+  result = rewriter
+               .create<linalg::PoolingNchwMaxOp>(
+                   loc, outTensorInitialized.getType(),
+                   ValueRange{paddedInput, windowTensor}, outTensorInitialized,
+                   stridesAttr, dilationAttr)
+               .getResult(0);
+  return success();
+}
+
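For reference, getOutputDimForConvOps is expected to build IR for the usual pooling output-size computation, and since ceil_mode is constrained to false above, the floor form of the formula applies. A minimal scalar sketch of that expectation, with a hypothetical helper name that is not part of the patch:

// Hypothetical helper (not in the patch) mirroring the Hout/Wout computation
// above, assuming ceil_mode == false:
//   out = floor((in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1
static int64_t expectedPoolOutSize(int64_t in, int64_t padding,
                                   int64_t dilation, int64_t kernelSize,
                                   int64_t stride) {
  return (in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride + 1;
}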
 namespace {
 class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
 public:
@@ -36,94 +121,117 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
+
+    Value self = adaptor.self();
+    Value maxPool2d;
+    if (failed(
+            computeMaxPool2d<AtenMaxPool2dOp>(op, rewriter, self, maxPool2d)))
+      return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// Returns the result of maxpool2d over the input tensor, along with the
+// corresponding indices into the input tensor for the values of the result
+// tensor.
+//
+// The result of the maxpool2d operation is calculated using the helper
+// function written above. For finding the indices, we follow the method
+// below:
+//
+// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will
+// also be 4-d tensors. Then:
+// for i in input.size[0]:
+//   for j in input.size[1]:
+//     for k in input.size[2]:
+//       for l in input.size[3]:
+//         for p in maxpool2d.size[1]:
+//           for q in maxpool2d.size[2]:
+//             for r in maxpool2d.size[3]:
+//               if input[i, j, k, l] == maxpool2d[i, p, q, r]:
+//                 indices[i, p, q, r] = (k * input.size[3] + l)
+//
+class ConvertAtenMaxPool2dWithIndicesOp
+    : public OpConversionPattern<AtenMaxPool2dWithIndicesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
     Location loc = op->getLoc();
     Value self = adaptor.self();
-    Value ceilMode = adaptor.ceil_mode();

-    Type elementType = self.getType().cast<RankedTensorType>().getElementType();
-    if (!elementType.isa<mlir::FloatType>())
-      return op.emitError("unimplemented: non-floating point type");
+    // Contains the result of maxpool2d operation over the input.
+    Value maxPool2d;
+    if (failed(computeMaxPool2d<AtenMaxPool2dWithIndicesOp>(op, rewriter, self,
+                                                            maxPool2d)))
+      return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");

-    // Pattern match against the op's original operands, because otherwise we
-    // will get the lowered version of the operands which is harder to pattern
-    // match.
-    SmallVector<int64_t, 2> strideInts;
-    if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
-      return rewriter.notifyMatchFailure(op,
-                                         "only support constant int strides");
-    SmallVector<int64_t, 2> dilationInts;
-    if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
-      return rewriter.notifyMatchFailure(op,
-                                         "only support constant int dilations");
-    SmallVector<int64_t, 2> paddingInts;
-    if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
-      return rewriter.notifyMatchFailure(op,
-                                         "only support constant int paddings");
-    SmallVector<int64_t, 2> kernelSizeInts;
-    if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
-      return rewriter.notifyMatchFailure(op, "only support kernel size ints");
-
-    Value falseValue = rewriter.create<arith::ConstantOp>(
-        loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
-    Value ceilModeFalse = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, ceilMode, falseValue);
-    rewriter.create<cf::AssertOp>(
-        loc, ceilModeFalse,
-        rewriter.getStringAttr("only ceil_mode false is supported"));
-
-    SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
-    paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
-                              paddingInts.end());
-    Value paddedInput = torch_to_linalg::getPaddedTensor(op, rewriter, self,
-                                                         paddingIncludingNC);
-
-    Value N = getDimOp(rewriter, loc, self, 0);
-    Value C = getDimOp(rewriter, loc, self, 1);
-    Value H = getDimOp(rewriter, loc, self, 2);
-    Value W = getDimOp(rewriter, loc, self, 3);
-
-    SmallVector<Value> paddingIntValues =
-        getAsConstantIntValues(rewriter, loc, paddingInts);
-    SmallVector<Value> dilationIntValues =
-        getAsConstantIntValues(rewriter, loc, dilationInts);
-    SmallVector<Value> kernelSizeIntValues =
-        getAsConstantIntValues(rewriter, loc, kernelSizeInts);
-    SmallVector<Value> strideIntValues =
-        getAsConstantIntValues(rewriter, loc, strideInts);
-
-    Value Hout = torch_to_linalg::getOutputDimForConvOps(
-        rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
-        kernelSizeIntValues[0], strideIntValues[0]);
-    Value Wout = torch_to_linalg::getOutputDimForConvOps(
-        rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
-        kernelSizeIntValues[1], strideIntValues[1]);
-
-    // Initialize output tensor with smallest floating point value
-    Value outTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, C, Hout, Wout}, elementType);
-    auto initialAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getSmallest(
-            elementType.cast<mlir::FloatType>().getFloatSemantics(),
-            /*Negative=*/true));
-    Value initValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
-    Value outTensorInitialized =
-        rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
-
-    auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
-    auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
-    Value windowTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
-        elementType);
-
-    Value maxPool2d = rewriter
-                          .create<linalg::PoolingNchwMaxOp>(
-                              loc, outTensorInitialized.getType(),
-                              ValueRange{paddedInput, windowTensor},
-                              outTensorInitialized, stridesAttr, dilationAttr)
-                          .getResult(0);
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    RankedTensorType indicesType = getTypeConverter()
+                                       ->convertType(op->getResult(1).getType())
+                                       .cast<RankedTensorType>();
+    unsigned resultRank = resultType.getRank();
+    SmallVector<Value> inputShape(getTensorSizes(rewriter, loc, self));
+    SmallVector<Value> resultShape(getTensorSizes(rewriter, loc, maxPool2d));
+
+    Value indicesTensor = createZeroInitTensor(rewriter, loc, resultShape,
+                                               indicesType.getElementType());
+
+    SmallVector<AffineExpr> inputExprs, maxPoolExprs, indicesExprs;
+    SmallVector<StringRef> iteratorTypes(2 * resultRank - 1,
+                                         getParallelIteratorTypeName());
+    AffineExpr zeroDimExpr = rewriter.getAffineDimExpr(0);
+    inputExprs.push_back(zeroDimExpr);
+    maxPoolExprs.push_back(zeroDimExpr);
+    indicesExprs.push_back(zeroDimExpr);
+
+    for (unsigned i = 1; i < resultRank; i++) {
+      inputExprs.push_back(rewriter.getAffineDimExpr(i));
+      maxPoolExprs.push_back(rewriter.getAffineDimExpr(i + resultRank - 1));
+      indicesExprs.push_back(rewriter.getAffineDimExpr(i + resultRank - 1));
+    }
+
+    auto indexingMaps =
+        AffineMap::inferFromExprList({inputExprs, maxPoolExprs, indicesExprs});
+
+    auto indicesResult =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, /*resultTensorTypes=*/indicesTensor.getType(),
+                /*inputs=*/ValueRange({self, maxPool2d}),
+                /*outputs=*/indicesTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value out = args[2];
+                  Value index = b.create<linalg::IndexOp>(loc, resultRank - 2);
+                  index = b.create<arith::MulIOp>(
+                      loc, index, inputShape[resultRank - 2 + 1]);
+                  index = b.create<arith::AddIOp>(
+                      loc, index,
+                      b.create<linalg::IndexOp>(loc, resultRank - 1));
+                  index = castIndexToInt(b, loc, index);
+                  Value predicate;
+                  if (resultType.getElementType().isa<mlir::FloatType>())
+                    predicate = b.create<arith::CmpFOp>(
+                        loc, arith::CmpFPredicate::OEQ, args[0], args[1]);
+                  else
+                    predicate = b.create<arith::CmpIOp>(
+                        loc, arith::CmpIPredicate::eq, args[0], args[1]);
+
+                  Value result =
+                      b.create<arith::SelectOp>(loc, predicate, index, out);
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    rewriter.replaceOp(op, {maxPool2d, indicesResult});
     return success();
   }
 };
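As a concrete reading of the region body above: the stored index is the offset of the matching input element within the flattened H*W plane of the input. A hand-worked sketch under assumed shapes (a 1x1x4x4 input with a 2x2 kernel and stride 2; the values are illustrative, not from the patch):

// Hand-worked sketch (not part of the patch): with input.size = [1, 1, 4, 4],
// if the maximum of the top-left 2x2 window sits at input[0][0][1][1], then
//   indices[0][0][0][0] = k * input.size[3] + l = 1 * 4 + 1 = 5,
// i.e. the position of that element in the flattened 4x4 (H*W) plane.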
@@ -234,6 +342,8 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenMaxPool2dOp>();
   patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
+  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
+  patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
   target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
   patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
 }
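For context on how these patterns are consumed: the populate function above only registers patterns and marks the ops illegal; the enclosing TorchToLinalg conversion pass is expected to drive them through the dialect conversion framework. A minimal driver sketch, with the argument order assumed and not taken from this patch:

// Hypothetical driver (not part of this patch): applying the registered
// patterns with a partial conversion, so only the ops marked illegal above
// must be rewritten.
static LogicalResult runPoolingConversion(Operation *op,
                                          TypeConverter &typeConverter) {
  MLIRContext *context = op->getContext();
  ConversionTarget target(*context);
  RewritePatternSet patterns(context);
  // Argument order assumed; in-tree, the TorchToLinalg pass supplies its own
  // converter, target, and pattern set.
  torch_to_linalg::populatePoolingPatternsAndLegality(typeConverter, patterns,
                                                      target);
  return applyPartialConversion(op, target, std::move(patterns));
}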