From ca70f08835eee98747578cb0f37af2c033ecc60b Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 8 Oct 2024 08:09:04 -0700 Subject: [PATCH] Revert https://github.com/triton-lang/triton/pull/4784 (#4865) Duplicating loads will significantly increase the shared memory usage which is likely to cause out of memory problem. We should find an alternative to not have to duplicate shared memory allocations. --- .../Pipeliner/MatmulLoopPipeline.cpp | 98 +++---------------- test/TritonGPU/loop-pipeline.mlir | 13 +-- 2 files changed, 17 insertions(+), 94 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index dc5f395c6753..c8f050eb529b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -219,8 +219,9 @@ static void createTMAAsyncCopy( // encodings, raise assertion, since incompatible shared encoding has been // handled in splitLoadsForIncompatible. static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; + incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) @@ -230,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) return std::nullopt; } else { if (!isa(user)) @@ -248,8 +250,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { bitWidth, /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. - if (attr != nullptr) - assert(attr == tempAttr && "incompatible shared encoding"); + if (attr != nullptr && attr != tempAttr) { + incompatible = true; + return std::nullopt; + } attr = tempAttr; } return attr; @@ -439,8 +443,13 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (auto dot = dyn_cast(use)) { + bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + // If we can't agree on a shared encoding skip pipelinig the load. + if (incompatible) + continue; } } else if (auto loadOp = dyn_cast(use)) { // The use of this loadOp is another loadOp. If the use is not in the @@ -476,83 +485,6 @@ assignMemoryLayouts(llvm::SmallVector> return loadToInfo; } -// Split users to groups, each group has the same shared encoding. -// If not all users are Dot encoding, return empty vector. -static DenseMap> -handleIncompatibleSharedEncoding(Operation *loadOp) { - DenseMap> loadGroups; - // Go through transitive uses of the loadOp in the same block. - for (Operation *user : loadOp->getUsers()) { - if (user->getBlock() != loadOp->getBlock()) - continue; - if (user->getNumResults() != 1) - return loadGroups; - - ttg::SharedEncodingAttr tempAttr; - if (auto memDesc = - dyn_cast(user->getResult(0).getType())) { - tempAttr = cast(memDesc.getEncoding()); - loadGroups[tempAttr].push_back(user); - } else { - if (!isa(user)) - return loadGroups; - auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()).getEncoding()); - if (!dotOpEnc) - return loadGroups; - auto srcTy = cast(loadOp->getResult(0).getType()); - auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( - loadOp->getContext(), dotOpEnc, srcTy.getShape(), - ttg::getOrder(srcTy.getEncoding()), - ttg::getCTALayout(srcTy.getEncoding()), - srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); - loadGroups[tempAttr].push_back(user); - } - } - return loadGroups; -} - -// Clone loads so each group of uses with same shared encoding will have a -// corresponding Load. -static void splitLoadsForIncompatible( - OpBuilder &builder, Operation *loadOp, - DenseMap> &lGroups) { - // The first group will use the original load, create new loads for other - // groups. - unsigned idx = 0; - builder.setInsertionPointAfter(loadOp); - for (auto pair : lGroups) { - SmallVector &group = pair.second; - if (idx++ == 0) - continue; - Operation *newLoad = builder.clone(*loadOp); - for (auto *user : group) { - user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0)); - } - } -} - -static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) { - // Get the list of all loads. - SmallVector loads; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (isa(op)) { - loads.push_back(&op); - } - } - OpBuilder builder(forOp); - for (auto *loadOp : loads) { - auto lGroups = handleIncompatibleSharedEncoding(loadOp); - LDBG("groups with different encoding: " << lGroups.size() << " " - << *loadOp); - if (lGroups.size() > 1) - splitLoadsForIncompatible(builder, loadOp, lGroups); - } -} - static llvm::MapVector scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet &rootUsers, int numStages) { @@ -1106,8 +1038,6 @@ static void invalidateBarriers(OpBuilder &builder, bool mlir::triton::preProcessLoopAndGetSchedule( scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { - splitLoadsWithIncompatibleEncoding(forOp); - // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index fca72ebda713..adaf515c800b 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -844,16 +844,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // check that the load with incompatiable shared encoding gets cloned and feeds into uses with same encoding - // AMD-NOT: alloc - // AMD: scf.for - // CHECK: local_alloc - // CHECK: local_alloc - // CHECK: scf.for - // CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1 - // CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0 - // CHECK: tt.dot - // CHECK: tt.trans %arg + // check that the load didn't get pipelined. + // COMMON-NOT: alloc + // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>