@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 ///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
-    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
-    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
-    const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+                           AffineMap indexingMap,
+                           ArrayRef<OpFoldResult> indexingSizes,
+                           const PadTilingInterfaceOptions &options) {
   Location loc = v.getLoc();
   SmallVector<OpFoldResult> paddedShape;
   auto tensorType = cast<RankedTensorType>(v.getType());
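As a concrete reading of the shape-evolution comment above, a minimal sketch (illustrative only, not part of this patch) of how substituting padded iteration-domain sizes into such an indexing expression yields the padded extent of one tensor dimension; the helper name `paddedExtentExample` is hypothetical:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Hypothetical helper: evaluate d0 * 3 + d1 at padded domain sizes (4, 16).
int64_t paddedExtentExample(MLIRContext *ctx) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr expr = d0 * 3 + d1; // the map from the comment above
  // Replacing dims by constants folds through the simplifying expr builders.
  AffineExpr folded = expr.replaceDims(
      {getAffineConstantExpr(4, ctx), getAffineConstantExpr(16, ctx)});
  return cast<AffineConstantExpr>(folded).getValue(); // 3 * 4 + 16 = 28
}
```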
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
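For context, the body that follows (only its opening lines are shown in this hunk) conceptually reduces to: collect the loop upper bounds from `iterationDomain` and delegate to `computePaddedShape` with the operand's indexing map. A hedged sketch, assuming `getMatchingIndexingMap` is the `IndexingMapOpInterface` accessor:

```cpp
// Sketch only; mirrors the visible structure, not a verbatim copy of the body.
SmallVector<OpFoldResult> loopUpperBounds;
loopUpperBounds.reserve(iterationDomain.size());
for (const Range &range : iterationDomain)
  loopUpperBounds.push_back(range.size); // domain assumed 0-offset, 1-stride

AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
    rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
    indexingMap, loopUpperBounds, options);
```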
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
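The visible tail of this function (the `paddingValue, /*nofold=*/false, dynDims` context line in the next hunk) indicates the operand is high-padded to `paddedShape`. A hedged sketch of the shape bookkeeping that plausibly precedes it, assuming constant entries become static dims and the rest stay dynamic:

```cpp
// Sketch of deriving the padded result type from `paddedShape` (assumed
// behavior; the real body also handles other element-type cases).
SmallVector<int64_t> staticShape;
SmallVector<Value> dynDims;
for (OpFoldResult ofr : paddedShape) {
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    staticShape.push_back(*cst);
  } else {
    staticShape.push_back(ShapedType::kDynamic);
    dynDims.push_back(cast<Value>(ofr)); // dynamic extent as an SSA value
  }
}
auto paddedType =
    RankedTensorType::get(staticShape, v.getType().getElementType());
// `paddedType`, the padding value, and `dynDims` then feed the high-pad
// creation visible in the context line below.
```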
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
       paddingValue, /*nofold=*/false, dynDims);
 }
 
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
-    RewriterBase &rewriter, TilingInterface opToPad,
-    const PadTilingInterfaceOptions &constOptions,
-    SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+    OpBuilder &builder, TilingInterface toPad,
+    PadTilingInterfaceOptions options,
     const PadSizeComputationFunction &computePaddingSizeFun) {
-  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+  SmallVector<tensor::PadOp> padOps;
+  Location loc = toPad.getLoc();
 
-  Location loc = opToPad.getLoc();
-  PadTilingInterfaceOptions options(constOptions);
   // Allow inference of pad values if they are not explicitly specified.
   // TODO: be mindful about the value depending on the actual operation.
   if (options.paddingValues.empty()) {
-    SmallVector<Type> types(opToPad->getOperandTypes());
-    llvm::append_range(types, opToPad->getResultTypes());
+    SmallVector<Type> types(toPad->getOperandTypes());
+    llvm::append_range(types, toPad->getResultTypes());
     for (Type t : types) {
       options.paddingValues.push_back(
-          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+          builder.getZeroAttr(getElementTypeOrSelf(t)));
     }
   }
 
-  if (llvm::any_of(opToPad->getOperands(),
+  if (llvm::any_of(toPad->getOperands(),
                    [](Value v) { return isa<MemRefType>(v.getType()); })) {
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "expected operation on tensors");
+    LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+    return failure();
   }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Set IP after opToPad because we also take the dims of opToPad's output.
-  rewriter.setInsertionPointAfter(opToPad);
+  OpBuilder::InsertionGuard g(builder);
+  // Set IP after toPad because we also take the dims of toPad's output.
+  builder.setInsertionPointAfter(toPad);
 
   // 1. Get the loopUpperBounds from the TilingInterface.
-  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+  SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
 
   // 2. For each operand.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad->getNumOperands());
-  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+  newOperands.reserve(toPad->getNumOperands());
+  for (OpOperand &opOperand : toPad->getOpOperands()) {
     Value operand = opOperand.get();
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
 
     // 2.a. Skip scalar-like operands.
     Type operandType = operand.getType();
@@ -311,27 +311,29 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
       newOperands.push_back(operand);
       continue;
     }
+
     // 2.a. Compute padded shape.
     FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
-        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+        computePaddingSizeFun(builder, opOperand, iterationDomain, options);
     if (failed(maybePaddedShape)) {
-      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+      LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+      return failure();
     }
 
     // 2.b. Expect proper `paddingValues`.
     // TODO: we may want to allow garbage padding in the future, in which case
     // we would just not assert.
     if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "--no padding value specified");
+      LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+      return failure();
     }
     Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];
 
     // 2.c. Perform actual padding.
-    Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
-        *maybePaddedShape, paddingValueAttr);
+    Value paddedOperand =
+        padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+                   *maybePaddedShape, paddingValueAttr);
     LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 
     // 2.d. Perform actual padding.
@@ -342,38 +344,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
 
   // 3. Form the resulting tensor::ExtractSliceOp.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
-    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "failed to reify result shapes");
+  if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+    return failure();
   }
-  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+  assert(reifiedResultShapes.size() == toPad->getNumResults() &&
          "expected same number of results");
 
-  // Clone `opToPad` to operate on the statically padded shapes.
+  // Clone `toPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
-  // clone **should** properly notify the rewriter.
+      ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder.
   TilingInterface paddedOp =
-      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+      clone(builder, toPad, resultTensorTypes, newOperands);
   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
 
-  // Recover the slice out of the new static results. This keeps the original
-  // opToPad around because it uses the dims of the original results.
+  // Recover the slice out of the new static results.
   SmallVector<Value> paddedSubtensorResults;
-  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  paddedSubtensorResults.reserve(toPad->getNumResults());
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
     paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
-        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
 
-  rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
-  return paddedOp;
+  return PadTilingInterfaceResult{padOps, paddedSubtensorResults, paddedOp};
 }
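Since the builder-based entry point no longer calls `replaceOp`, pattern-based callers must now perform the replacement themselves. A hedged migration sketch; the field name `replacements` for the second struct member is inferred from the aggregate return above and should be checked against the header:

```cpp
// Caller-side sketch (assumed field names on PadTilingInterfaceResult).
FailureOr<PadTilingInterfaceResult> maybePadded =
    linalg::rewriteAsPaddedOp(rewriter, toPad, options, computePaddingSizeFun);
if (failed(maybePadded))
  return rewriter.notifyMatchFailure(toPad, "failed to pad op");
// The OpBuilder-based API leaves `toPad` in place; erase it in the caller.
rewriter.replaceOp(toPad, maybePadded->replacements);
```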