Skip to content

Commit e839fb2

Browse files
committed
add default tileSize option to pass
1 parent 111d276 commit e839fb2

File tree

2 files changed

+99
-50
lines changed

2 files changed

+99
-50
lines changed

include/gc/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
8585
Option<"useCostModel", "use-cost-model", "bool",
8686
/*default=*/"false",
8787
"Decide if enable cost model to control iterative fusion.">,
88+
ListOption<"defaultTileSize", "default-tile-size", "std::string",
89+
"Set default TileSize for the certain type of op, saying matmul:{32,32}">,
8890
];
8991
}
9092

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 97 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Transforms/RegionUtils.h"
2929
#include <llvm/Support/Debug.h>
3030
#include <memory>
31+
#include <unordered_map>
3132

3233
#include "TilingUsingInterfaceX.h"
3334

@@ -601,45 +602,6 @@ static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
601602
return success(walkResult.wasInterrupted());
602603
}
603604

604-
template <typename OpTy>
605-
static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
606-
// a. Check <OpTy>
607-
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
608-
return false;
609-
auto tilingInterfaceOp = cast<TilingInterface>(op);
610-
611-
scf::SCFTilingOptions options;
612-
// b. Get default tiling size
613-
SmallVector<utils::IteratorType> iteratorTypes =
614-
tilingInterfaceOp.getLoopIteratorTypes();
615-
616-
SmallVector<OpFoldResult> defaultTileSize(iteratorTypes.size(),
617-
rewriter.getIndexAttr(0));
618-
619-
for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
620-
// All outer non reduction loop should contribute parallelism. In another
621-
// word, all reduction dimensions should not be tiled.
622-
if (iterType == utils::IteratorType::parallel &&
623-
(en != iteratorTypes.size() - 1 ||
624-
llvm::count(iteratorTypes, utils::IteratorType::reduction))) {
625-
defaultTileSize[en] = rewriter.getIndexAttr(1);
626-
}
627-
}
628-
629-
options.setTileSizes(defaultTileSize);
630-
// c. Set loop type
631-
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
632-
// d. Use builtin tiling interface
633-
FailureOr<scf::SCFTilingResult> tilingResult =
634-
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
635-
if (succeeded(tilingResult)) {
636-
rewriter.replaceOp(op, tilingResult->replacements);
637-
return true;
638-
} else {
639-
return false;
640-
}
641-
}
642-
643605
struct SystemDesc {
644606
// get runtime OMP_NUM_THREADS
645607
uint32_t getNumThreads() {
@@ -696,9 +658,61 @@ struct SystemDesc {
696658
MLIRContext *ctx;
697659
};
698660

661+
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
662+
663+
template <typename OpTy>
664+
static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
665+
const OpTileSizeMap &tsMap) {
666+
// a. Check <OpTy>
667+
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
668+
return false;
669+
auto tilingInterfaceOp = cast<TilingInterface>(op);
670+
671+
scf::SCFTilingOptions options;
672+
// b. Get default tiling size
673+
SmallVector<utils::IteratorType> iteratorTypes =
674+
tilingInterfaceOp.getLoopIteratorTypes();
675+
676+
SmallVector<OpFoldResult> defaultTileSize;
677+
678+
std::string opName = op->getName().getStringRef().str();
679+
// Erase dialect name, such as Linalg or Tensor.
680+
opName.erase(0, opName.find(".") + 1);
681+
682+
if (tsMap.count(opName)) {
683+
SmallVector<int64_t> userDefaultTileSize = tsMap.find(opName)->second;
684+
defaultTileSize =
685+
getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
686+
} else {
687+
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
688+
for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
689+
// All outer non reduction loop should contribute parallelism. In another
690+
// word, all reduction dimensions should not be tiled.
691+
if (iterType == utils::IteratorType::parallel &&
692+
(en != iteratorTypes.size() - 1 ||
693+
llvm::count(iteratorTypes, utils::IteratorType::reduction))) {
694+
defaultTileSize[en] = rewriter.getIndexAttr(1);
695+
}
696+
}
697+
}
698+
699+
options.setTileSizes(defaultTileSize);
700+
// c. Set loop type
701+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
702+
// d. Use builtin tiling interface
703+
FailureOr<scf::SCFTilingResult> tilingResult =
704+
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
705+
if (succeeded(tilingResult)) {
706+
rewriter.replaceOp(op, tilingResult->replacements);
707+
return true;
708+
} else {
709+
return false;
710+
}
711+
}
712+
699713
void iterativeTilingAndFusionUntilExhaustion(
700714
RewriterBase &rewriter, func::FuncOp &f,
701-
const CandidateSliceOptions &sliceOptions) {
715+
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
702716
// Collect untiled and tiled ops respectively
703717
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
704718

@@ -756,26 +770,57 @@ void iterativeTilingAndFusionUntilExhaustion(
756770
} else {
757771
// Auto tiling with default tile size if no tiled op found. Follow tiling
758772
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
759-
SmallVector<std::function<bool(RewriterBase &, Operation *)>>
773+
SmallVector<std::function<bool(RewriterBase &, Operation *,
774+
const OpTileSizeMap &)>>
760775
priorityTilingPipeLine = {
761776
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
762777
defaultTilingOfType<mlir::linalg::ReduceOp>,
763778
defaultTilingOfType<mlir::linalg::LinalgOp>};
764-
if (llvm::all_of(
765-
priorityTilingPipeLine,
766-
[&rewriter, &unTiledOps](
767-
function_ref<bool(RewriterBase &, Operation *)> tilingFn) {
768-
return !llvm::any_of(unTiledOps,
769-
std::bind(tilingFn, std::ref(rewriter),
770-
std::placeholders::_1));
771-
})) {
779+
if (llvm::all_of(priorityTilingPipeLine,
780+
[&rewriter, &tsMap, &unTiledOps](
781+
function_ref<bool(RewriterBase &, Operation *,
782+
const OpTileSizeMap &)>
783+
tilingFn) {
784+
return !llvm::any_of(
785+
unTiledOps, std::bind(tilingFn, std::ref(rewriter),
786+
std::placeholders::_1,
787+
std::cref(tsMap)));
788+
})) {
772789
// If no op can be tiled
773790
break;
774791
}
775792
}
776793
}
777794
}
778795

796+
static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) {
797+
OpTileSizeMap tsMap;
798+
char warning[] =
799+
"Please follow correct argument format: opType:{ts1,ts2,...}";
800+
for (auto str : strArgs) {
801+
str.erase(llvm::remove_if(str, llvm::isSpace), str.end());
802+
size_t pos = str.find(":");
803+
if (pos == std::string::npos) {
804+
llvm_unreachable(warning);
805+
}
806+
std::string opType = str.substr(0, pos);
807+
std::string strTileSize = str.erase(0, pos + 1);
808+
if (strTileSize.size() <= 2 || strTileSize.front() != '{' ||
809+
strTileSize.back() != '}') {
810+
llvm_unreachable(warning);
811+
}
812+
strTileSize = strTileSize.substr(1, strTileSize.size() - 2);
813+
SmallVector<int64_t> intTileSize;
814+
while ((pos = strTileSize.find(",")) != std::string::npos) {
815+
intTileSize.push_back(std::stoi(strTileSize.substr(0, pos)));
816+
strTileSize.erase(0, pos + 1);
817+
}
818+
intTileSize.push_back(std::stoi(strTileSize));
819+
tsMap[opType] = intTileSize;
820+
}
821+
return tsMap;
822+
}
823+
779824
struct IterativeTilingAndFusion
780825
: public impl::IterativeTilingAndFusionBase<IterativeTilingAndFusion> {
781826
using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase;
@@ -808,10 +853,12 @@ struct IterativeTilingAndFusion
808853
};
809854
sliceOptions.addFilter(costModelFilter);
810855
}
856+
OpTileSizeMap tsMap = defaultTileSizeParser(defaultTileSize);
811857
// Get rewriter
812858
IRRewriter rewriter(&ctx);
813859
// Run iterative fusion
814-
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions);
860+
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions,
861+
tsMap);
815862
}
816863
};
817864

0 commit comments

Comments
 (0)