
Commit

Revert #4784 (#4865)
Duplicating loads significantly increases shared memory usage, which is likely to cause out-of-memory problems. We should find an alternative that does not require duplicating shared memory allocations.
ThomasRaoux authored Oct 8, 2024
1 parent 53166ef commit ca70f08
Showing 2 changed files with 17 additions and 94 deletions.
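
The pattern this revert restores is summarized below in a minimal standalone sketch; it is not the Triton API. `Encoding`, `getCommonEncoding`, and the user-encoding list are placeholder names standing in for `ttg::SharedEncodingAttr`, `getSharedEncIfAllUsersAreDotEnc`, and the load's transitive users: when users disagree on a shared encoding, the helper reports the mismatch through an out-parameter and the caller skips pipelining that load instead of cloning it.

```cpp
// Minimal sketch (placeholder types, not the Triton API) of the control flow
// restored by this revert: report incompatibility instead of duplicating loads.
#include <iostream>
#include <optional>
#include <vector>

using Encoding = int; // stand-in for ttg::SharedEncodingAttr

// Returns the common encoding if all users agree; sets `incompatible` when two
// users require different encodings (the case #4784 handled by cloning the load).
std::optional<Encoding> getCommonEncoding(const std::vector<Encoding> &userEncodings,
                                          bool &incompatible) {
  incompatible = false;
  std::optional<Encoding> common;
  for (Encoding e : userEncodings) {
    if (common && *common != e) {
      incompatible = true; // users disagree: flag it rather than asserting
      return std::nullopt;
    }
    common = e;
  }
  return common;
}

int main() {
  bool incompatible = false;
  // Two users that want different shared encodings.
  auto enc = getCommonEncoding({/*userA=*/0, /*userB=*/1}, incompatible);
  if (incompatible) {
    // Caller's policy after the revert: leave the load unpipelined rather than
    // duplicating it, which would double the shared-memory allocations.
    std::cout << "skip pipelining this load\n";
  } else if (enc) {
    std::cout << "pipeline with encoding " << *enc << "\n";
  }
  return 0;
}
```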
98 changes: 14 additions & 84 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -219,8 +219,9 @@ static void createTMAAsyncCopy(
// encodings, raise assertion, since incompatible shared encoding has been
// handled in splitLoadsForIncompatible.
static std::optional<ttg::SharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val) {
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
ttg::SharedEncodingAttr attr;
incompatible = false;
for (Operation *user : val.getUsers()) {
ttg::SharedEncodingAttr tempAttr;
if (user->getNumResults() != 1)
@@ -230,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
// First time we find a shared encoding in the chain, save it and try to
// use it if it is compatible with the other users.
tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value())
if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible)
.has_value())
return std::nullopt;
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
Expand All @@ -248,8 +250,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
bitWidth, /*needTrans=*/false);
}
// Check that the shared encodings needed by the users are compatible.
if (attr != nullptr)
assert(attr == tempAttr && "incompatible shared encoding");
if (attr != nullptr && attr != tempAttr) {
incompatible = true;
return std::nullopt;
}
attr = tempAttr;
}
return attr;
@@ -439,8 +443,13 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
loadInfo.sharedEncoding =
getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr);
} else if (auto dot = dyn_cast<tt::DotOp>(use)) {
bool incompatible = false;
loadInfo.sharedEncoding =
getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);
getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
.value_or(nullptr);
// If we can't agree on a shared encoding, skip pipelining the load.
if (incompatible)
continue;
}
} else if (auto loadOp = dyn_cast<tt::LoadOp>(use)) {
// The use of this loadOp is another loadOp. If the use is not in the
@@ -476,83 +485,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
return loadToInfo;
}

// Split users to groups, each group has the same shared encoding.
// If not all users are Dot encoding, return empty vector.
static DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>>
handleIncompatibleSharedEncoding(Operation *loadOp) {
DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> loadGroups;
// Go through transitive uses of the loadOp in the same block.
for (Operation *user : loadOp->getUsers()) {
if (user->getBlock() != loadOp->getBlock())
continue;
if (user->getNumResults() != 1)
return loadGroups;

ttg::SharedEncodingAttr tempAttr;
if (auto memDesc =
dyn_cast<triton::MemDescType>(user->getResult(0).getType())) {
tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
loadGroups[tempAttr].push_back(user);
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
return loadGroups;
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
if (!dotOpEnc)
return loadGroups;
auto srcTy = cast<TensorOrMemDesc>(loadOp->getResult(0).getType());
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SharedEncodingAttr::get(
loadOp->getContext(), dotOpEnc, srcTy.getShape(),
ttg::getOrder(srcTy.getEncoding()),
ttg::getCTALayout(srcTy.getEncoding()),
srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
loadGroups[tempAttr].push_back(user);
}
}
return loadGroups;
}

// Clone loads so each group of uses with same shared encoding will have a
// corresponding Load.
static void splitLoadsForIncompatible(
OpBuilder &builder, Operation *loadOp,
DenseMap<ttg::SharedEncodingAttr, SmallVector<Operation *>> &lGroups) {
// The first group will use the original load, create new loads for other
// groups.
unsigned idx = 0;
builder.setInsertionPointAfter(loadOp);
for (auto pair : lGroups) {
SmallVector<Operation *> &group = pair.second;
if (idx++ == 0)
continue;
Operation *newLoad = builder.clone(*loadOp);
for (auto *user : group) {
user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0));
}
}
}

static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) {
// Get the list of all loads.
SmallVector<Operation *> loads;
for (Operation &op : forOp.getBody()->without_terminator()) {
if (isa<tt::LoadOp, tt::ExperimentalDescriptorLoadOp>(op)) {
loads.push_back(&op);
}
}
OpBuilder builder(forOp);
for (auto *loadOp : loads) {
auto lGroups = handleIncompatibleSharedEncoding(loadOp);
LDBG("groups with different encoding: " << lGroups.size() << " "
<< *loadOp);
if (lGroups.size() > 1)
splitLoadsForIncompatible(builder, loadOp, lGroups);
}
}

static llvm::MapVector<Operation *, LoadInfo>
scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule,
DenseSet<Operation *> &rootUsers, int numStages) {
@@ -1106,8 +1038,6 @@ static void invalidateBarriers(OpBuilder &builder,

bool mlir::triton::preProcessLoopAndGetSchedule(
scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
splitLoadsWithIncompatibleEncoding(forOp);

// Schedule the loads and root ops (dot ops) in the loop. This will give us
// a scaffold for the final schedule.
DenseSet<Operation *> rootUsers;
13 changes: 3 additions & 10 deletions test/TritonGPU/loop-pipeline.mlir
@@ -844,16 +844,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
%15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
// check that the load with incompatible shared encoding gets cloned and feeds into uses with the same encoding
// AMD-NOT: alloc
// AMD: scf.for
// CHECK: local_alloc
// CHECK: local_alloc
// CHECK: scf.for
// CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1
// CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0
// CHECK: tt.dot
// CHECK: tt.trans %arg
// check that the load didn't get pipelined.
// COMMON-NOT: alloc
// COMMON: scf.for
%17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
%18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
%19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
