8
8
#include " mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
9
9
10
10
#include " mlir/Dialect/Affine/IR/AffineOps.h"
11
+ #include " mlir/Dialect/Affine/Utils.h"
11
12
#include " mlir/Dialect/Arith/IR/Arith.h"
12
13
#include " mlir/Dialect/Arith/Utils/Utils.h"
13
14
#include " mlir/Dialect/Func/IR/FuncOps.h"
15
+ #include " mlir/Dialect/Linalg/Utils/Utils.h"
14
16
#include " mlir/Dialect/SCF/Utils/Utils.h"
15
17
#include " mlir/Dialect/Tensor/IR/Tensor.h"
16
18
#include " mlir/Dialect/Utils/IndexingUtils.h"
@@ -388,6 +390,64 @@ static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
388
390
return success ();
389
391
}
390
392
393
+ // Considering multi-level tensor.*SliceOp maybe based on different
394
+ // coordination, this utils compute the real SliceParameters coordinated on ROOT
395
+ // SliceOp. E.g
396
+ // %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
397
+ // %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
398
+ //
399
+ // where the coordination can be illustrated as follow:
400
+ //
401
+ // %3 ----------------------------------
402
+ // | | |
403
+ // | OFFSET2 | OFFSET1 |
404
+ // | ------ %0 |
405
+ // | |
406
+ // | |
407
+ // |------------------ %1 ------ |
408
+ // | | SIZE1 |
409
+ // | | |
410
+ // | | |
411
+ // | | ------- |
412
+ // |
413
+ //
414
+ // The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
415
+ static FailureOr<linalg::SliceParameters>
416
+ computeRealSliceParamCoordinatedRootSliceOp (
417
+ RewriterBase &rewriter, Location loc,
418
+ OffsetSizeAndStrideOpInterface candidateSliceOp,
419
+ MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
420
+ if (llvm::any_of (candidateSliceOp.getMixedStrides (), [](OpFoldResult stride) {
421
+ return !isConstantIntValue (stride, 1 );
422
+ })) {
423
+ return rewriter.notifyMatchFailure (candidateSliceOp,
424
+ " candidateSliceOp has stride" );
425
+ }
426
+ SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets ();
427
+ // real offsets equals to accumulative offsets of outer candidates
428
+ for (auto iter = candidateSliceOpList.rbegin (); *iter != candidateSliceOp;
429
+ iter++) {
430
+ // assert each outer candidate slice has no stride
431
+ if (llvm::any_of (iter->getMixedStrides (), [](OpFoldResult stride) {
432
+ return !isConstantIntValue (stride, 1 );
433
+ })) {
434
+ return failure ();
435
+ }
436
+ for (auto &&[ofr1, ofr2] :
437
+ llvm::zip_equal (realOffsets, iter->getMixedOffsets ())) {
438
+ using AVE = affine::AffineValueExpr;
439
+ affine::AffineBuilder ab (rewriter, loc);
440
+ AffineExpr dim0, dim1, sym;
441
+ bindDims (rewriter.getContext (), dim0, dim1);
442
+ bindSymbols (rewriter.getContext (), sym);
443
+ auto aveOffset1 = AVE (dim0).bind (ofr1), aveOffset2 = AVE (dim1).bind (ofr2);
444
+ ofr1 = ab.add (aveOffset1, aveOffset2);
445
+ }
446
+ }
447
+ return linalg::SliceParameters{realOffsets, candidateSliceOp.getMixedSizes (),
448
+ candidateSliceOp.getMixedStrides ()};
449
+ }
450
+
391
451
// / Implementation of fusing consumer of a single slice by computing the
392
452
// / slice of the consumer in-place for scf loop.
393
453
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -401,7 +461,8 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
401
461
402
462
// 1. Get the real consumer of candidate
403
463
// tensor.insert_slice/parallel_insert_slice by walking through
404
- // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the way
464
+ // scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
465
+ // way.
405
466
FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
406
467
resultAndSliceOpsPair = scfX::getResultOfTopLevelLoopYieldInsertSliceOp (
407
468
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
@@ -422,7 +483,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
422
483
consumerOp, " consumer op's operand doesn't seem to be an OpResult" );
423
484
}
424
485
425
- // 2. Get all outer loops of candidateSliceOp
486
+ // 2. Get all outer loops of candidateSliceOp.
426
487
SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp (
427
488
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
428
489
@@ -445,7 +506,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
445
506
}
446
507
ValueRange newInitAppend = dpsInits;
447
508
448
- // 4. reconstruct nested loop from outer to inner
509
+ // 4. reconstruct nested loop from outer to inner.
449
510
SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
450
511
(*resultAndSliceOpsPair).second ;
451
512
SmallVector<LoopLikeOpInterface> newOuterLoops;
@@ -513,7 +574,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
513
574
newOuterLoops.push_back (newLoopOp);
514
575
}
515
576
516
- // 5.a reconstruct inner-most loop
577
+ // 5.a reconstruct inner-most loop.
517
578
LoopLikeOpInterface oldInnerMostLoop = outerLoops.back (), newInnerMostLoop;
518
579
Location loc = oldInnerMostLoop->getLoc ();
519
580
rewriter.setInsertionPoint (oldInnerMostLoop);
@@ -545,7 +606,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
545
606
rewriter.mergeBlocks (oldLoopBody, newLoopBody,
546
607
newLoopBody->getArguments ().take_front (oldNumArguments));
547
608
// 5.c replace the result of old oldInnerMostLoop with newInnerMostLoop's
548
- // results
609
+ // results.
549
610
rewriter.replaceOp (oldInnerMostLoop,
550
611
newInnerMostLoop->getResults ().take_front (
551
612
oldInnerMostLoop->getNumResults ()));
@@ -555,20 +616,26 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
555
616
// candidateSliceOp whereas in the scf.forall case this is created from the
556
617
// operands of tensor.parallel_insert_slice.
557
618
tensor::InsertSliceOp clonedInsertSliceOp;
558
- // we need to compute real offset and size for multi-level insertSliceOp
559
- // according the candidateSliceOpList
560
619
if (auto sliceOp =
561
620
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
562
621
auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
563
622
rewriter.setInsertionPoint (newForallOp.getTerminator ());
564
- clonedInsertSliceOp = rewriter.create <tensor::InsertSliceOp>(
565
- loc, sliceOp.getSource (), sliceOp.getDest (), sliceOp.getMixedOffsets (),
566
- sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
567
623
} else {
568
624
rewriter.setInsertionPoint (candidateSliceOp);
569
- clonedInsertSliceOp =
570
- cast<tensor::InsertSliceOp>(rewriter.clone (*candidateSliceOp));
571
625
}
626
+ FailureOr<linalg::SliceParameters> realSliceParams =
627
+ computeRealSliceParamCoordinatedRootSliceOp (
628
+ rewriter, loc, cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp),
629
+ candidateSliceOpList);
630
+ if (failed (realSliceParams))
631
+ return failure ();
632
+ // create dummy insertSliceOp to align with the requirement of current
633
+ // Tiling interface and fix potential semantic mismatch with later
634
+ // extractSliceOp generated by `getTiledImplementation`.
635
+ clonedInsertSliceOp = rewriter.create <tensor::InsertSliceOp>(
636
+ loc, candidateSliceOp->getOperand (0 ), candidateSliceOp->getOperand (1 ),
637
+ (*realSliceParams).offsets , (*realSliceParams).sizes ,
638
+ (*realSliceParams).strides );
572
639
573
640
// 7.a. Clone consumer op.
574
641
auto newForOpBlockArgsForConsumerDest =
@@ -623,7 +690,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
623
690
624
691
newOuterLoops.push_back (cast<LoopLikeOpInterface>(newInnerMostLoop));
625
692
626
- // 10. reconstruct terminator of outer loop by inner loop
693
+ // 10. reconstruct terminator of outer loop by inner loop.
627
694
auto outerCandidateIter = candidateSliceOpList.rbegin ();
628
695
for (auto [outerLoop, innerLoop] :
629
696
llvm::zip_equal (MutableArrayRef (newOuterLoops).drop_back (),
0 commit comments