|
28 | 28 | #include "mlir/Transforms/RegionUtils.h"
|
29 | 29 | #include <llvm/Support/Debug.h>
|
30 | 30 | #include <memory>
|
| 31 | +#include <unordered_map> |
31 | 32 |
|
32 | 33 | #include "TilingUsingInterfaceX.h"
|
33 | 34 |
|
@@ -601,45 +602,6 @@ static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
|
601 | 602 | return success(walkResult.wasInterrupted());
|
602 | 603 | }
|
603 | 604 |
|
604 |
| -template <typename OpTy> |
605 |
| -static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) { |
606 |
| - // a. Check <OpTy> |
607 |
| - if (!isa<TilingInterface>(op) || !isa<OpTy>(op)) |
608 |
| - return false; |
609 |
| - auto tilingInterfaceOp = cast<TilingInterface>(op); |
610 |
| - |
611 |
| - scf::SCFTilingOptions options; |
612 |
| - // b. Get default tiling size |
613 |
| - SmallVector<utils::IteratorType> iteratorTypes = |
614 |
| - tilingInterfaceOp.getLoopIteratorTypes(); |
615 |
| - |
616 |
| - SmallVector<OpFoldResult> defaultTileSize(iteratorTypes.size(), |
617 |
| - rewriter.getIndexAttr(0)); |
618 |
| - |
619 |
| - for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) { |
620 |
| - // All outer non reduction loop should contribute parallelism. In another |
621 |
| - // word, all reduction dimensions should not be tiled. |
622 |
| - if (iterType == utils::IteratorType::parallel && |
623 |
| - (en != iteratorTypes.size() - 1 || |
624 |
| - llvm::count(iteratorTypes, utils::IteratorType::reduction))) { |
625 |
| - defaultTileSize[en] = rewriter.getIndexAttr(1); |
626 |
| - } |
627 |
| - } |
628 |
| - |
629 |
| - options.setTileSizes(defaultTileSize); |
630 |
| - // c. Set loop type |
631 |
| - options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
632 |
| - // d. Use builtin tiling interface |
633 |
| - FailureOr<scf::SCFTilingResult> tilingResult = |
634 |
| - scf::tileUsingSCF(rewriter, tilingInterfaceOp, options); |
635 |
| - if (succeeded(tilingResult)) { |
636 |
| - rewriter.replaceOp(op, tilingResult->replacements); |
637 |
| - return true; |
638 |
| - } else { |
639 |
| - return false; |
640 |
| - } |
641 |
| -} |
642 |
| - |
643 | 605 | struct SystemDesc {
|
644 | 606 | // get runtime OMP_NUM_THREADS
|
645 | 607 | uint32_t getNumThreads() {
|
@@ -696,9 +658,61 @@ struct SystemDesc {
|
696 | 658 | MLIRContext *ctx;
|
697 | 659 | };
|
698 | 660 |
|
| 661 | +using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>; |
| 662 | + |
| 663 | +template <typename OpTy> |
| 664 | +static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op, |
| 665 | + const OpTileSizeMap &tsMap) { |
| 666 | + // a. Check <OpTy> |
| 667 | + if (!isa<TilingInterface>(op) || !isa<OpTy>(op)) |
| 668 | + return false; |
| 669 | + auto tilingInterfaceOp = cast<TilingInterface>(op); |
| 670 | + |
| 671 | + scf::SCFTilingOptions options; |
| 672 | + // b. Get default tiling size |
| 673 | + SmallVector<utils::IteratorType> iteratorTypes = |
| 674 | + tilingInterfaceOp.getLoopIteratorTypes(); |
| 675 | + |
| 676 | + SmallVector<OpFoldResult> defaultTileSize; |
| 677 | + |
| 678 | + std::string opName = op->getName().getStringRef().str(); |
| 679 | + // Erase dialect name, such as Linalg or Tensor. |
| 680 | + opName.erase(0, opName.find(".") + 1); |
| 681 | + |
| 682 | + if (tsMap.count(opName)) { |
| 683 | + SmallVector<int64_t> userDefaultTileSize = tsMap.find(opName)->second; |
| 684 | + defaultTileSize = |
| 685 | + getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize)); |
| 686 | + } else { |
| 687 | + defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0)); |
| 688 | + for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) { |
| 689 | + // All outer non reduction loop should contribute parallelism. In another |
| 690 | + // word, all reduction dimensions should not be tiled. |
| 691 | + if (iterType == utils::IteratorType::parallel && |
| 692 | + (en != iteratorTypes.size() - 1 || |
| 693 | + llvm::count(iteratorTypes, utils::IteratorType::reduction))) { |
| 694 | + defaultTileSize[en] = rewriter.getIndexAttr(1); |
| 695 | + } |
| 696 | + } |
| 697 | + } |
| 698 | + |
| 699 | + options.setTileSizes(defaultTileSize); |
| 700 | + // c. Set loop type |
| 701 | + options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
| 702 | + // d. Use builtin tiling interface |
| 703 | + FailureOr<scf::SCFTilingResult> tilingResult = |
| 704 | + scf::tileUsingSCF(rewriter, tilingInterfaceOp, options); |
| 705 | + if (succeeded(tilingResult)) { |
| 706 | + rewriter.replaceOp(op, tilingResult->replacements); |
| 707 | + return true; |
| 708 | + } else { |
| 709 | + return false; |
| 710 | + } |
| 711 | +} |
| 712 | + |
699 | 713 | void iterativeTilingAndFusionUntilExhaustion(
|
700 | 714 | RewriterBase &rewriter, func::FuncOp &f,
|
701 |
| - const CandidateSliceOptions &sliceOptions) { |
| 715 | + const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) { |
702 | 716 | // Collect untiled and tiled ops respectively
|
703 | 717 | llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
|
704 | 718 |
|
@@ -756,26 +770,57 @@ void iterativeTilingAndFusionUntilExhaustion(
|
756 | 770 | } else {
|
757 | 771 | // Auto tiling with default tile size if no tiled op found. Follow tiling
|
758 | 772 | // priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
|
759 |
| - SmallVector<std::function<bool(RewriterBase &, Operation *)>> |
| 773 | + SmallVector<std::function<bool(RewriterBase &, Operation *, |
| 774 | + const OpTileSizeMap &)>> |
760 | 775 | priorityTilingPipeLine = {
|
761 | 776 | defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
|
762 | 777 | defaultTilingOfType<mlir::linalg::ReduceOp>,
|
763 | 778 | defaultTilingOfType<mlir::linalg::LinalgOp>};
|
764 |
| - if (llvm::all_of( |
765 |
| - priorityTilingPipeLine, |
766 |
| - [&rewriter, &unTiledOps]( |
767 |
| - function_ref<bool(RewriterBase &, Operation *)> tilingFn) { |
768 |
| - return !llvm::any_of(unTiledOps, |
769 |
| - std::bind(tilingFn, std::ref(rewriter), |
770 |
| - std::placeholders::_1)); |
771 |
| - })) { |
| 779 | + if (llvm::all_of(priorityTilingPipeLine, |
| 780 | + [&rewriter, &tsMap, &unTiledOps]( |
| 781 | + function_ref<bool(RewriterBase &, Operation *, |
| 782 | + const OpTileSizeMap &)> |
| 783 | + tilingFn) { |
| 784 | + return !llvm::any_of( |
| 785 | + unTiledOps, std::bind(tilingFn, std::ref(rewriter), |
| 786 | + std::placeholders::_1, |
| 787 | + std::cref(tsMap))); |
| 788 | + })) { |
772 | 789 | // If no op can be tiled
|
773 | 790 | break;
|
774 | 791 | }
|
775 | 792 | }
|
776 | 793 | }
|
777 | 794 | }
|
778 | 795 |
|
| 796 | +static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) { |
| 797 | + OpTileSizeMap tsMap; |
| 798 | + char warning[] = |
| 799 | + "Please follow correct argument format: opType:{ts1,ts2,...}"; |
| 800 | + for (auto str : strArgs) { |
| 801 | + str.erase(llvm::remove_if(str, llvm::isSpace), str.end()); |
| 802 | + size_t pos = str.find(":"); |
| 803 | + if (pos == std::string::npos) { |
| 804 | + llvm_unreachable(warning); |
| 805 | + } |
| 806 | + std::string opType = str.substr(0, pos); |
| 807 | + std::string strTileSize = str.erase(0, pos + 1); |
| 808 | + if (strTileSize.size() <= 2 || strTileSize.front() != '{' || |
| 809 | + strTileSize.back() != '}') { |
| 810 | + llvm_unreachable(warning); |
| 811 | + } |
| 812 | + strTileSize = strTileSize.substr(1, strTileSize.size() - 2); |
| 813 | + SmallVector<int64_t> intTileSize; |
| 814 | + while ((pos = strTileSize.find(",")) != std::string::npos) { |
| 815 | + intTileSize.push_back(std::stoi(strTileSize.substr(0, pos))); |
| 816 | + strTileSize.erase(0, pos + 1); |
| 817 | + } |
| 818 | + intTileSize.push_back(std::stoi(strTileSize)); |
| 819 | + tsMap[opType] = intTileSize; |
| 820 | + } |
| 821 | + return tsMap; |
| 822 | +} |
| 823 | + |
779 | 824 | struct IterativeTilingAndFusion
|
780 | 825 | : public impl::IterativeTilingAndFusionBase<IterativeTilingAndFusion> {
|
781 | 826 | using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase;
|
@@ -808,10 +853,12 @@ struct IterativeTilingAndFusion
|
808 | 853 | };
|
809 | 854 | sliceOptions.addFilter(costModelFilter);
|
810 | 855 | }
|
| 856 | + OpTileSizeMap tsMap = defaultTileSizeParser(defaultTileSize); |
811 | 857 | // Get rewriter
|
812 | 858 | IRRewriter rewriter(&ctx);
|
813 | 859 | // Run iterative fusion
|
814 |
| - iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions); |
| 860 | + iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions, |
| 861 | + tsMap); |
815 | 862 | }
|
816 | 863 | };
|
817 | 864 |
|
|
0 commit comments