Skip to content

Commit e627983

Browse files
committed
add coordination on multi-level anchor
1 parent d130996 commit e627983

File tree

1 file changed

+80
-13
lines changed

1 file changed

+80
-13
lines changed

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
99

1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Affine/Utils.h"
1112
#include "mlir/Dialect/Arith/IR/Arith.h"
1213
#include "mlir/Dialect/Arith/Utils/Utils.h"
1314
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1416
#include "mlir/Dialect/SCF/Utils/Utils.h"
1517
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1618
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -388,6 +390,64 @@ static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
388390
return success();
389391
}
390392

393+
// Considering multi-level tensor.*SliceOp maybe based on different
394+
// coordination, this utils compute the real SliceParameters coordinated on ROOT
395+
// SliceOp. E.g
396+
// %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
397+
// %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
398+
//
399+
// where the coordination can be illustrated as follow:
400+
//
401+
// %3 ----------------------------------
402+
// | | |
403+
// | OFFSET2 | OFFSET1 |
404+
// | ------ %0 |
405+
// | |
406+
// | |
407+
// |------------------ %1 ------ |
408+
// | | SIZE1 |
409+
// | | |
410+
// | | |
411+
// | | ------- |
412+
// |
413+
//
414+
// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
415+
static FailureOr<linalg::SliceParameters>
416+
computeRealSliceParamCoordinatedRootSliceOp(
417+
RewriterBase &rewriter, Location loc,
418+
OffsetSizeAndStrideOpInterface candidateSliceOp,
419+
MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
420+
if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
421+
return !isConstantIntValue(stride, 1);
422+
})) {
423+
return rewriter.notifyMatchFailure(candidateSliceOp,
424+
"candidateSliceOp has stride");
425+
}
426+
SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
427+
// real offsets equals to accumulative offsets of outer candidates
428+
for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
429+
iter++) {
430+
// assert each outer candidate slice has no stride
431+
if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
432+
return !isConstantIntValue(stride, 1);
433+
})) {
434+
return failure();
435+
}
436+
for (auto &&[ofr1, ofr2] :
437+
llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
438+
using AVE = affine::AffineValueExpr;
439+
affine::AffineBuilder ab(rewriter, loc);
440+
AffineExpr dim0, dim1, sym;
441+
bindDims(rewriter.getContext(), dim0, dim1);
442+
bindSymbols(rewriter.getContext(), sym);
443+
auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
444+
ofr1 = ab.add(aveOffset1, aveOffset2);
445+
}
446+
}
447+
return linalg::SliceParameters{realOffsets, candidateSliceOp.getMixedSizes(),
448+
candidateSliceOp.getMixedStrides()};
449+
}
450+
391451
/// Implementation of fusing consumer of a single slice by computing the
392452
/// slice of the consumer in-place for scf loop.
393453
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -401,7 +461,8 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
401461

402462
// 1. Get the real consumer of candidate
403463
// tensor.insert_slice/parallel_insert_slice by walking through
404-
// scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the way
464+
// scf.for/scf.forall and collect all [Parallel]insertSliceOp(s) along the
465+
// way.
405466
FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
406467
resultAndSliceOpsPair = scfX::getResultOfTopLevelLoopYieldInsertSliceOp(
407468
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
@@ -422,7 +483,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
422483
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
423484
}
424485

425-
// 2. Get all outer loops of candidateSliceOp
486+
// 2. Get all outer loops of candidateSliceOp.
426487
SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsOfSliceOp(
427488
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
428489

@@ -445,7 +506,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
445506
}
446507
ValueRange newInitAppend = dpsInits;
447508

448-
// 4. reconstruct nested loop from outer to inner
509+
// 4. reconstruct nested loop from outer to inner.
449510
SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList =
450511
(*resultAndSliceOpsPair).second;
451512
SmallVector<LoopLikeOpInterface> newOuterLoops;
@@ -513,7 +574,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
513574
newOuterLoops.push_back(newLoopOp);
514575
}
515576

516-
// 5.a reconstruct inner-most loop
577+
// 5.a reconstruct inner-most loop.
517578
LoopLikeOpInterface oldInnerMostLoop = outerLoops.back(), newInnerMostLoop;
518579
Location loc = oldInnerMostLoop->getLoc();
519580
rewriter.setInsertionPoint(oldInnerMostLoop);
@@ -545,7 +606,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
545606
rewriter.mergeBlocks(oldLoopBody, newLoopBody,
546607
newLoopBody->getArguments().take_front(oldNumArguments));
547608
// 5.c replace the result of old oldInnerMostLoop with newInnerMostLoop's
548-
// results
609+
// results.
549610
rewriter.replaceOp(oldInnerMostLoop,
550611
newInnerMostLoop->getResults().take_front(
551612
oldInnerMostLoop->getNumResults()));
@@ -555,20 +616,26 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
555616
// candidateSliceOp whereas in the scf.forall case this is created from the
556617
// operands of tensor.parallel_insert_slice.
557618
tensor::InsertSliceOp clonedInsertSliceOp;
558-
// we need to compute real offset and size for multi-level insertSliceOp
559-
// according the candidateSliceOpList
560619
if (auto sliceOp =
561620
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
562621
auto newForallOp = cast<scf::ForallOp>(newInnerMostLoop);
563622
rewriter.setInsertionPoint(newForallOp.getTerminator());
564-
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
565-
loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
566-
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
567623
} else {
568624
rewriter.setInsertionPoint(candidateSliceOp);
569-
clonedInsertSliceOp =
570-
cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
571625
}
626+
FailureOr<linalg::SliceParameters> realSliceParams =
627+
computeRealSliceParamCoordinatedRootSliceOp(
628+
rewriter, loc, cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp),
629+
candidateSliceOpList);
630+
if (failed(realSliceParams))
631+
return failure();
632+
// create dummy insertSliceOp to align with the requirement of current
633+
// Tiling interface and fix potential semantic mismatch with later
634+
// extractSliceOp generated by `getTiledImplementation`.
635+
clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
636+
loc, candidateSliceOp->getOperand(0), candidateSliceOp->getOperand(1),
637+
(*realSliceParams).offsets, (*realSliceParams).sizes,
638+
(*realSliceParams).strides);
572639

573640
// 7.a. Clone consumer op.
574641
auto newForOpBlockArgsForConsumerDest =
@@ -623,7 +690,7 @@ mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
623690

624691
newOuterLoops.push_back(cast<LoopLikeOpInterface>(newInnerMostLoop));
625692

626-
// 10. reconstruct terminator of outer loop by inner loop
693+
// 10. reconstruct terminator of outer loop by inner loop.
627694
auto outerCandidateIter = candidateSliceOpList.rbegin();
628695
for (auto [outerLoop, innerLoop] :
629696
llvm::zip_equal(MutableArrayRef(newOuterLoops).drop_back(),

0 commit comments

Comments
 (0)