Skip to content

Commit 43177a3

Browse files
authored
[Transform][Fusion] revert consumer fusion for tensor.pack (#280)
* revert consumer fusion for `tensor.pack`. * add `ND-Tile` knob to control granularity of tiling or fusion.
1 parent f12c806 commit 43177a3

File tree

3 files changed

+80
-36
lines changed

3 files changed

+80
-36
lines changed

include/gc/Transforms/Passes.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
6060
let description = [{
6161
The pass tries to fuse any MLIR operation which can be tiled. Moreover, this pass aims to support:
6262
1. Matmul fusion with element-wise/reduce/broadcast ops.
63-
2. Pre-op and post-op fusion.
64-
3. Multi-consumer and multi-producer support.
65-
4. Multiple level of nest loops and candidates.
63+
2. Producer and consumer fusion.
64+
3. Arbitrary topology, including residual pattern with multiple consumers.
65+
4. Nest loops structure with multiple level candidates.
6666
5. Flexible option to control the boundary of iterative process.
6767
6. Default tiling when no op is tiled before fusion.
6868
7. Cost-model to determine whether to fuse or not.
@@ -74,8 +74,11 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
7474
Option<"useCostModel", "use-cost-model", "bool",
7575
/*default=*/"false",
7676
"Decide if enable cost model to control iterative fusion.">,
77+
Option<"defaultNDTile", "default-nd-tile", "unsigned",
78+
/*default=*/"2",
79+
"Set default amount of non-one dimensions in TileSize, such as 1, 2[default, a.k.a. 2D-Tile], etc.">,
7780
ListOption<"defaultTileSize", "default-tile-size", "std::string",
78-
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`">,
81+
"Set default TileSize for the certain type of op, saying `matmul:{32,32}`.">,
7982
];
8083
}
8184
def DeepTileContractionOp

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,11 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
167167
tileSizesOnInnerDims =
168168
llvm::to_vector(ArrayRef(tileSizes).take_back(innerTiles.size()));
169169
} else {
170-
// Upstream doesn't implement `getTiledImplementationFromOperandTile`
171-
// interface of `packOp` so far. In another word, `packOp` could not be
172-
// fused as consumer. As a result, just return failure currently.
173-
return failure();
170+
// tileSize comes from OpOperand
171+
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
172+
for (auto &pos : innerDimPos) {
173+
tileSizesOnInnerDims.push_back(tileSizes[pos]);
174+
}
174175
}
175176
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp)) {
176177
innerTiles = unPackOp.getMixedTiles();
@@ -478,8 +479,8 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
478479
return std::nullopt;
479480

480481
// c. Check the producer of root source if is tilable.
481-
Operation *producer = realProducer->getDefiningOp<TilingInterface>();
482-
if (!producer)
482+
Operation *producerOp = realProducer->getDefiningOp<TilingInterface>();
483+
if (!producerOp)
483484
return std::nullopt;
484485

485486
CandidateDefOrUse defOrUse{*realProducer};
@@ -536,8 +537,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
536537
SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
537538
for (auto useOperand : *realConsumers) {
538539
// c. Check the consumer of top level result if is tilable.
539-
Operation *consumer = dyn_cast<TilingInterface>(useOperand->getOwner());
540-
if (!consumer)
540+
Operation *consumerOp = dyn_cast<TilingInterface>(useOperand->getOwner());
541+
if (!consumerOp)
541542
continue;
542543

543544
CandidateDefOrUse defOrUse{useOperand};
@@ -559,7 +560,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
559560
// f. Manually run cse on region which contains original consumer op in
560561
// avoid of conflict with subsequent `tileAndFuseConsumerOfSlice` get nest
561562
// loops between next candidate sliceOp and tiled producer.
562-
(void)mlir::simplifyRegions(rewriter, {*consumer->getParentRegion()});
563+
(void)mlir::simplifyRegions(rewriter, {*consumerOp->getParentRegion()});
563564
}
564565
}
565566
if (fusedResultList.empty())
@@ -647,11 +648,18 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
647648

648649
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
649650

651+
struct defaultTileConfig {
652+
// OpTy-to-TileSize mapping
653+
OpTileSizeMap tsMap;
654+
// ND-tile size
655+
unsigned ndTile;
656+
};
657+
650658
/// Default Tiling function only effective for certain `OpTy` operation
651659
static FailureOr<scf::SCFTilingResult>
652660
defaultTilingOfType(RewriterBase &rewriter, Operation *op,
653661
function_ref<bool(Operation *)> isaOpTy,
654-
const OpTileSizeMap &tsMap) {
662+
const defaultTileConfig &cfg) {
655663
// a. Check <OpTy>
656664
if (!isa<TilingInterface>(op) || !isaOpTy(op))
657665
return failure();
@@ -672,18 +680,20 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
672680
// Erase dialect name, such as Linalg or Tensor.
673681
opName.erase(0, opName.find(".") + 1);
674682

675-
if (tsMap.count(opName)) {
676-
SmallVector<int64_t> userDefaultTileSize = tsMap.find(opName)->second;
683+
if (cfg.tsMap.count(opName)) {
684+
SmallVector<int64_t> userDefaultTileSize = cfg.tsMap.find(opName)->second;
677685
defaultTileSize =
678686
getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
679687
} else {
680688
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
681689
// Try tileSize from `32` to `16`.
682690
SmallVector<int64_t> tsOrder = {32, 16};
683-
// Only 2D tile is expected.
684-
int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
685-
? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
686-
: 0;
691+
// Record how many dims have been tiled, including fully tiled, i.e.
692+
// tileSize == dimSize.
693+
unsigned nonOneTileDims =
694+
(isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
695+
? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
696+
: 0;
687697
// Reverse both of iteration type and domain from inner to outer.
688698
std::reverse(iteratorTypes.begin(), iteratorTypes.end());
689699
std::reverse(iterationDomain.begin(), iterationDomain.end());
@@ -692,21 +702,29 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
692702
// All parallel iterator will be tiled by `32` or `16`. If need
693703
// specified, please set option `defaultTileSize`, like `matmul:{64,64}`.
694704
if (iterType == utils::IteratorType::parallel) {
695-
Range curDomain = iterationDomain[en];
696-
std::optional<int64_t> tripCount = mlir::constantTripCount(
697-
curDomain.offset, curDomain.size, curDomain.stride);
698-
if (tileDims >= 2 && en > 0) {
705+
if (nonOneTileDims >= cfg.ndTile && en > 0) {
699706
defaultTileSize[en] = rewriter.getIndexAttr(1);
700707
continue;
701-
} else if (tripCount) {
708+
}
709+
Range curDomain = iterationDomain[en];
710+
if (std::optional<int64_t> tripCount = mlir::constantTripCount(
711+
curDomain.offset, curDomain.size, curDomain.stride)) {
712+
// skip dummy tiling.
713+
if (tripCount == 1)
714+
continue;
702715
for (auto &ts : tsOrder) {
703-
if (*tripCount % ts == 0 && *tripCount > ts) {
716+
// If `tripCount` equals `tileSize`, do NOT explicitly tile it to
717
// avoid a non-zero offset.
718+
if (*tripCount == ts)
719+
break;
720+
if (*tripCount % ts == 0) {
704721
defaultTileSize[en] = rewriter.getIndexAttr(ts);
705722
break;
706723
}
707724
}
708725
}
709-
tileDims++;
726+
// Fallback to fully tiled.
727+
nonOneTileDims++;
710728
}
711729
}
712730
}
@@ -731,7 +749,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
731749

732750
void iterativeTilingAndFusionUntilExhaustion(
733751
RewriterBase &rewriter, func::FuncOp &f,
734-
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
752+
const CandidateSliceOptions &sliceOptions, const defaultTileConfig &cfg) {
735753
// Collect untiled and tiled ops respectively
736754
llvm::SetVector<Operation *> tiledOps, unTiledOps;
737755

@@ -799,7 +817,7 @@ void iterativeTilingAndFusionUntilExhaustion(
799817
for (auto &isaOpTy : priorityOpTypeOrder) {
800818
for (auto &op : unTiledOps) {
801819
FailureOr<scf::SCFTilingResult> tilingResult =
802-
defaultTilingOfType(rewriter, op, isaOpTy, tsMap);
820+
defaultTilingOfType(rewriter, op, isaOpTy, cfg);
803821
if (succeeded(tilingResult)) {
804822
tiledOps.insert(tilingResult->tiledOps[0]);
805823
rewriter.replaceOp(op, tilingResult->replacements);
@@ -881,8 +899,8 @@ struct IterativeTilingAndFusion
881899
// Get rewriter
882900
IRRewriter rewriter(&ctx);
883901
// Run iterative fusion
884-
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions,
885-
tsMap);
902+
iterativeTilingAndFusionUntilExhaustion(
903+
rewriter, func, sliceOptions, defaultTileConfig{tsMap, defaultNDTile});
886904
}
887905
};
888906

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,18 +339,41 @@ module {
339339

340340
module {
341341
/// CHECK-LABEL: @not_fuse_pack
342-
func.func @not_fuse_pack(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<1x32x4096xbf16>) -> tensor<1x1x128x32x32xbf16> {
343-
%dest0 = tensor.empty() : tensor<1x32x4096xbf16>
342+
func.func @not_fuse_pack(%arg0: tensor<1x35x4096xbf16>, %arg1: tensor<1x35x4096xbf16>) -> tensor<1x2x128x32x32xbf16> {
343+
%dest0 = tensor.empty() : tensor<1x35x4096xbf16>
344344
/// CHECK: scf.forall
345345
/// CHECK: linalg.add
346-
%add = linalg.add ins(%arg0, %arg1 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%dest0 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
346+
%add = linalg.add ins(%arg0, %arg1 : tensor<1x35x4096xbf16>, tensor<1x35x4096xbf16>) outs(%dest0 : tensor<1x35x4096xbf16>) -> tensor<1x35x4096xbf16>
347347
/// CHECK: }
348-
%dest1 = tensor.empty() : tensor<1x1x128x32x32xbf16>
348+
%dest1 = tensor.empty() : tensor<1x2x128x32x32xbf16>
349+
%pad = arith.constant 0.000000e+00 : bf16
349350
/// CHECK: %[[PACK_OUT:.*]] = scf.forall
350351
/// CHECK: tensor.pack
351-
%pack = tensor.pack %add outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x32x4096xbf16> -> tensor<1x1x128x32x32xbf16>
352+
%pack = tensor.pack %add padding_value(%pad : bf16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x35x4096xbf16> -> tensor<1x2x128x32x32xbf16>
352353
/// CHECK: }
353354
/// CHECK: return %[[PACK_OUT]]
355+
return %pack : tensor<1x2x128x32x32xbf16>
356+
}
357+
}
358+
359+
// -----
360+
361+
module {
362+
/// CHECK-LABEL: @fuse_pack
363+
func.func @fuse_pack(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<1x32x4096xbf16>) -> tensor<1x1x128x32x32xbf16> {
364+
%dest0 = tensor.empty() : tensor<1x32x4096xbf16>
365+
/// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) = (0, 0) to (1, 4096) step (1, 32)
366+
/// CHECK: linalg.add
367+
%add = linalg.add ins(%arg0, %arg1 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%dest0 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16>
368+
%dest1 = tensor.empty() : tensor<1x1x128x32x32xbf16>
369+
/// CHECK-NEXT: affine.apply
370+
/// CHECK-NEXT: tensor.extract_slice
371+
/// CHECK-NEXT: tensor.pack
372+
%pack = tensor.pack %add outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32] into %dest1 : tensor<1x32x4096xbf16> -> tensor<1x1x128x32x32xbf16>
373+
/// CHECK: scf.forall.in_parallel
374+
/// CHECK: tensor.parallel_insert_slice
375+
/// CHECK: tensor.parallel_insert_slice
376+
/// CHECK: return %[[FINAL_RESULT]]#1
354377
return %pack : tensor<1x1x128x32x32xbf16>
355378
}
356379
}

0 commit comments

Comments
 (0)