Skip to content

[Transform][Fusion] optimize default tileSize to 2D-tile #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 68 additions & 43 deletions lib/gc/Transforms/IterativeTilingAndFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "gc/Analysis/TargetDescriptionAnalysis.h"
#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "gc/Dialect/Linalgx/Utils.h"
#include "gc/Transforms/Passes.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/Traits.h"
Expand Down Expand Up @@ -166,11 +167,10 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
tileSizesOnInnerDims =
llvm::to_vector(ArrayRef(tileSizes).take_back(innerTiles.size()));
} else {
// tileSize comes from OpOperand
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
for (auto &pos : innerDimPos) {
tileSizesOnInnerDims.push_back(tileSizes[pos]);
}
      // Upstream doesn't implement the `getTiledImplementationFromOperandTile`
      // interface of `packOp` so far. In other words, `packOp` cannot be
      // fused as a consumer. As a result, just return failure for now.
return failure();
}
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp)) {
innerTiles = unPackOp.getMixedTiles();
Expand Down Expand Up @@ -215,7 +215,7 @@ nonContractionOpFilter(RewriterBase &rewriter,
CandidateDefOrUse defOrUse) {
// Currently this pass focuses on fine-grained fusion, which does not expect
// two consecutive contraction ops.
return failure(isa<mlir::linalg::ContractionOpInterface>(defOrUse.ownerOp));
return failure(linalgx::isMatmulOp(defOrUse.ownerOp));
}

/// If fusing multiple consumers is allowed, there may exist following cases:
Expand Down Expand Up @@ -635,29 +635,36 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
// 3. check whether has either extract or insert slice op
auto walkResult = forOp->walk(
[](tensor::ExtractSliceOp) { return WalkResult::interrupt(); });
if (walkResult.wasInterrupted())
return success();
walkResult = forOp->walk(
[](tensor::InsertSliceOp) { return WalkResult::interrupt(); });
if (!walkResult.wasInterrupted())
return failure();
walkResult = forOp->walk([](OffsetSizeAndStrideOpInterface op) {
return isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(op)
? WalkResult::interrupt()
: WalkResult::advance();
});
return success(walkResult.wasInterrupted());
}

using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;

/// Default Tiling function only effective for certain `OpTy` operation
template <typename OpTy>
static FailureOr<scf::SCFTilingResult>
defaultTilingOfType(RewriterBase &rewriter, Operation *op,
function_ref<bool(Operation *)> isaOpTy,
const OpTileSizeMap &tsMap) {
// a. Check <OpTy>
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
if (!isa<TilingInterface>(op) || !isaOpTy(op))
return failure();
auto tilingInterfaceOp = cast<TilingInterface>(op);

scf::SCFTilingOptions options;
// b. Get default tiling size
SmallVector<utils::IteratorType> iteratorTypes =
tilingInterfaceOp.getLoopIteratorTypes();
llvm::SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(rewriter);
assert(iteratorTypes.size() == iterationDomain.size() &&
"Iteration domain expected as same long as iteration type");

SmallVector<OpFoldResult> defaultTileSize;

Expand All @@ -671,15 +678,40 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
} else {
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
// Try tileSize from `32` to `16`.
SmallVector<int64_t> tsOrder = {32, 16};
// Only 2D tile is expected.
int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp(op))
? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops()
: 0;
    // Reverse both the iterator types and the iteration domain, from inner to
    // outer.
std::reverse(iteratorTypes.begin(), iteratorTypes.end());
std::reverse(iterationDomain.begin(), iterationDomain.end());

for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
      // All outer non-reduction loops should contribute parallelism. In other
      // words, no reduction dimension should be tiled.
if (iterType == utils::IteratorType::parallel &&
(en != iteratorTypes.size() - 1 ||
llvm::count(iteratorTypes, utils::IteratorType::reduction)))
defaultTileSize[en] = rewriter.getIndexAttr(1);
      // Every parallel iterator will be tiled by `32` or `16`. If specific
      // sizes are needed, set the option `defaultTileSize`, e.g.
      // `matmul:{64,64}`.
if (iterType == utils::IteratorType::parallel) {
Range curDomain = iterationDomain[en];
std::optional<int64_t> tripCount = mlir::constantTripCount(
curDomain.offset, curDomain.size, curDomain.stride);
if (tileDims >= 2 && en > 0) {
defaultTileSize[en] = rewriter.getIndexAttr(1);
continue;
} else if (tripCount) {
for (auto &ts : tsOrder) {
if (*tripCount % ts == 0 && *tripCount > ts) {
defaultTileSize[en] = rewriter.getIndexAttr(ts);
break;
}
}
}
tileDims++;
}
}
}
// Reverse back default TileSize.
std::reverse(defaultTileSize.begin(), defaultTileSize.end());
// If the tile sizes are all zero, no tiling would happen.
if (llvm::all_of(defaultTileSize, isZeroIndex))
return failure();
Expand All @@ -697,20 +729,6 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
return tilingResult;
}

template <typename OpTy1, typename OpTy2, typename... Rest>
static FailureOr<scf::SCFTilingResult>
defaultTilingOfType(RewriterBase &rewriter, Operation *op,
const OpTileSizeMap &tsMap) {
FailureOr<scf::SCFTilingResult> tilingResult =
defaultTilingOfType<OpTy1>(rewriter, op, tsMap);
if (failed(tilingResult))
return defaultTilingOfType<OpTy2, Rest...>(rewriter, op, tsMap);
return tilingResult;
}

using DefaultTilingFn = std::function<FailureOr<scf::SCFTilingResult>(
RewriterBase &, Operation *, const OpTileSizeMap &)>;

void iterativeTilingAndFusionUntilExhaustion(
RewriterBase &rewriter, func::FuncOp &f,
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
Expand Down Expand Up @@ -738,9 +756,8 @@ void iterativeTilingAndFusionUntilExhaustion(

// Walk through funcOp
f->walk([&tiledOps](Operation *op) {
if (succeeded(isTiledOpInLoop(op))) {
if (succeeded(isTiledOpInLoop(op)))
tiledOps.insert(op);
}
});

// Iterative tiling and fusion until exhaustion.
Expand All @@ -764,17 +781,25 @@ void iterativeTilingAndFusionUntilExhaustion(
// Auto tiling with default tile size if no tiled op found. Follow tiling
// priority based on OpTy:
// `ContractionOp`->`ReductionOp`->`LinalgOp`->`TensorOp`.
SmallVector<DefaultTilingFn> priorityTilingPipeLine = {
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
defaultTilingOfType<mlir::linalg::ReduceOp>,
defaultTilingOfType<mlir::linalg::LinalgOp>,
defaultTilingOfType<tensor::PackOp, tensor::UnPackOp, tensor::PadOp>,
defaultTilingOfType<TilingInterface>};

for (auto &tilingFn : priorityTilingPipeLine) {
SmallVector<std::function<bool(Operation *)>> priorityOpTypeOrder = {
// Generate helper function to check if isa<OpTy>.
#define GenIsaOpTy(opTy) [](Operation *op) { return opTy(op); }
// If ContractionOp
GenIsaOpTy(linalgx::isMatmulOp),
// If ReduceOp
GenIsaOpTy(isa<mlir::linalg::ReduceOp>),
// If other LinalgOp
GenIsaOpTy(isa<mlir::linalg::LinalgOp>),
// If TensorOp
GenIsaOpTy((isa<tensor::PackOp, tensor::UnPackOp, tensor::PadOp>)),
// Fallback
GenIsaOpTy(isa<TilingInterface>)};
#undef GenIsaOpTy
mlir::topologicalSort(unTiledOps);
for (auto &isaOpTy : priorityOpTypeOrder) {
for (auto &op : unTiledOps) {
FailureOr<scf::SCFTilingResult> tilingResult =
tilingFn(rewriter, op, tsMap);
defaultTilingOfType(rewriter, op, isaOpTy, tsMap);
if (succeeded(tilingResult)) {
tiledOps.insert(tilingResult->tiledOps[0]);
rewriter.replaceOp(op, tilingResult->replacements);
Expand Down
71 changes: 62 additions & 9 deletions test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module {
/// CHECK: tensor.empty
%dest = tensor.empty() : tensor<512x256xbf16>
%unpack = tensor.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest : tensor<32x8x16x32xbf16> -> tensor<512x256xbf16>
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
%2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1) -> (tensor<128x256xbf16>) {
%5 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
%6 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
Expand Down Expand Up @@ -105,7 +105,7 @@ module {
%cst = arith.constant 0.000000e+00 : f32
%dest0 = tensor.empty() : tensor<256x256xf32>
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 2)
%1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
%iv0 = affine.apply #map(%arg4)
%iv1 = affine.apply #map(%arg5)
Expand Down Expand Up @@ -157,7 +157,7 @@ module {
%cst = arith.constant 0.000000e+00 : f32
%dest0 = tensor.empty() : tensor<256x256xf32>
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 1)
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (2, 1)
%1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
%iv0 = affine.apply #map(%arg3)
%iv1 = affine.apply #map(%arg4)
Expand Down Expand Up @@ -205,12 +205,12 @@ module {
%dest0 = tensor.empty() : tensor<128x256x256xf32>
%0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>
%dest1 = tensor.empty() : tensor<128x256xf32>
/// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
/// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}) = (0, 0) to (128, 256) step (1, 32)
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
/// CHECK: tensor.extract_slice {{.*}} [1, 256, 32] [1, 1, 1]
/// CHECK: %[[ADD_OUT:.*]] = linalg.add
/// CHECK: tensor.extract_slice {{.*}} [1, 1] [1, 1]
/// CHECK: tensor.extract_slice {{.*}} [1, 32] [1, 1]
/// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
%1 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256x256xf32>) outs(%dest1 : tensor<128x256xf32>) dimensions = [1]
/// CHECK: scf.forall.in_parallel
Expand Down Expand Up @@ -319,7 +319,7 @@ module {
/// CHECK-LABEL: @fuse_residual_pattern
func.func @fuse_residual_pattern(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256x256xf32> {
%dest0 = tensor.empty() : tensor<128x256x256xf32>
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
/// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) = (0, 0, 0) to (128, 256, 256) step (1, 32, 32)
/// CHECK: %[[ADD_OUT:.*]] = linalg.add
/// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT:.*]] :
/// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[ADD_OUT:.*]], %[[EXP_OUT:.*]] :
Expand Down Expand Up @@ -353,4 +353,57 @@ module {
/// CHECK: return %[[PACK_OUT]]
return %pack : tensor<1x1x128x32x32xbf16>
}
}

// -----

module {
// CHECK: func.func @fuse_generic_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
/// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
%0 = tensor.empty() : tensor<2x2x16x16xf32>
%pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
/// CHECK: %[[EMPTY_OUT_1:.*]] = tensor.empty
%1 = tensor.empty() : tensor<2x16x16xf32>
/// CHECK: %[[FIRST_MATMUL_OUT:.*]] = scf.forall (%{{.*}}) in (2)
/// CHECK: %[[EXTRACT_SLICE_0:.*]] = tensor.extract_slice %[[ARG0]]{{.*}} [16, 32]
/// CHECK: %[[EXTRACT_SLICE_1:.*]] = tensor.extract_slice %[[EMPTY_OUT_0]]{{.*}} [1, 2, 16, 16]
/// CHECK: %[[PACK_OUT:.*]] = tensor.pack %[[EXTRACT_SLICE_0]]
/// CHECK: %[[EXTRACT_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]]{{.*}} [2, 16, 16]
/// CHECK: %[[MATMUL_OUT_0:.*]] = linalg.generic {{.*}} ins(%[[PACK_OUT]], %[[EXTRACT_SLICE_2]] :
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>], iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %arg1 : tensor<2x2x16x16xf32>, tensor<2x16x16xf32>) outs(%1 : tensor<2x16x16xf32>) {
^bb0(%in: f32, %in_3: f32, %out: f32):
%9 = arith.mulf %in, %in_3 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<2x16x16xf32>
/// CHECK: scf.forall.in_parallel
/// CHECK: tensor.parallel_insert_slice
/// CHECK: }
/// CHECK: %[[EMPTY_OUT_2:.*]] = tensor.empty
/// CHECK: %[[EMPTY_OUT_3:.*]] = tensor.empty
%3 = tensor.empty() : tensor<2x4x16x16xf32>
/// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 4)
/// CHECK: %[[EXTRACT_SLICE_3:.*]] = tensor.extract_slice %[[FIRST_MATMUL_OUT]]{{.*}} [1, 16, 16]
/// CHECK: %[[EXTRACT_SLICE_4:.*]] = tensor.extract_slice %[[ARG2]]{{.*}} [1, 16, 16]
/// CHECK: %[[MATMUL_OUT_1:.*]] = linalg.generic {{.*}} ins(%[[EXTRACT_SLICE_3]], %[[EXTRACT_SLICE_4]] :
/// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack %[[MATMUL_OUT_1]]
%4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%2, %arg2 : tensor<2x16x16xf32>, tensor<4x16x16xf32>) outs(%3 : tensor<2x4x16x16xf32>) {
^bb0(%in: f32, %in_3: f32, %out: f32):
%9 = arith.mulf %in, %in_3 : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<2x4x16x16xf32>
/// CHECK: scf.forall.in_parallel
/// CHECK: tensor.parallel_insert_slice
/// CHECK: tensor.parallel_insert_slice
/// CHECK: }
%5 = tensor.empty() : tensor<32x64xf32>
%unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %5 : tensor<2x4x16x16xf32> -> tensor<32x64xf32>
/// CHECK: return %[[FINAL_RESULT]]#1
return %unpack : tensor<32x64xf32>
}
}