[Blackwell] Refactor/slightly generalize warp specialization #6597

Merged · 61 commits · May 1, 2025
Commits (61)
7fc06c0
start introducing tokens
Mogball Apr 15, 2025
4023111
hoist tmem alloc
Mogball Apr 16, 2025
56e1c8c
cleanup
Mogball Apr 16, 2025
225f241
add test for sinking into conditional
Mogball Apr 16, 2025
94d991c
fix tests and some bugs
Mogball Apr 16, 2025
9b742a0
fix repl token
Mogball Apr 16, 2025
720a700
fix aws test
Mogball Apr 16, 2025
107daf0
fix test
Mogball Apr 16, 2025
19e37d4
fixing tests, remove TMEM tokens
Mogball Apr 16, 2025
c09d3c5
separate pass for removing TMEM tokens
Mogball Apr 17, 2025
d2acd19
fix tests
Mogball Apr 17, 2025
8521ccf
schedule loops
Mogball Apr 17, 2025
db65c51
bench
Mogball Apr 17, 2025
4a9a5ed
fix compile only test
Mogball Apr 17, 2025
c7cd1c4
delete dead code
Mogball Apr 17, 2025
ad6d7de
unused forward decl
Mogball Apr 17, 2025
303df26
Merge remote-tracking branch 'origin/main' into mogball/tmem_toks
Mogball Apr 17, 2025
7bc72fc
[Blackwell] Support DescriptorLoadOp when deciding to use shared memo…
csullivan Apr 14, 2025
7e608e3
[Bench][Blackwell] Support optional scale TMAs in warp specialization…
csullivan Apr 18, 2025
0dffe75
hoisttmemalloc checks that tokens are present
Mogball Apr 21, 2025
8aa9165
add doc about tokens to op definitions
Mogball Apr 21, 2025
1349758
Merge branch 'mogball/tmem_toks' into mogball/fmha
Mogball Apr 21, 2025
54884c4
Merge remote-tracking branch 'origin/csullivan/support_block_scales_i…
Mogball Apr 21, 2025
7767409
simplify util
Mogball Apr 21, 2025
9d4ebe2
refactor LoadMMASpecialization to support any number of loads
Mogball Apr 21, 2025
9e018e5
fix handling cycle in user partition
Mogball Apr 21, 2025
4a72ab8
refactor loads into loadgroups
Mogball Apr 22, 2025
695eb2a
Merge branch 'main' into mogball/tmem_toks
Mogball Apr 22, 2025
d63cb82
fix
Mogball Apr 22, 2025
bffcb5b
cleanup packLL utilities
Mogball Apr 23, 2025
e8f28b4
Merge branch 'main' into mogball/tmem_toks
Mogball Apr 23, 2025
d6a78f4
WIP refactoring...
Mogball Apr 24, 2025
63b7da0
Merge remote-tracking branch 'origin/main' into mogball/fmha
Mogball Apr 24, 2025
9f1fc29
Revert "[Blackwell] Support DescriptorLoadOp when deciding to use sha…
Mogball Apr 24, 2025
b9b73f9
Revert "[Bench][Blackwell] Support optional scale TMAs in warp specia…
Mogball Apr 24, 2025
fa8b255
Merge branch 'main' into mogball/tmem_toks
Mogball Apr 24, 2025
c5f8cb4
fix conflict
Mogball Apr 24, 2025
8342d7a
Merge branch 'mogball/tmem_toks' into mogball/fmha
Mogball Apr 24, 2025
2f74f02
loads work
Mogball Apr 24, 2025
6e5b526
mmas are a pain
Mogball Apr 25, 2025
2af0311
Merge remote-tracking branch 'origin/main' into mogball/tmem_toks
Mogball Apr 25, 2025
5860abe
Merge branch 'mogball/tmem_toks' into mogball/fmha
Mogball Apr 25, 2025
1f507ac
done but does it work?
Mogball Apr 25, 2025
dda423b
it deadlocks
Mogball Apr 25, 2025
d965b3b
works but ends too early
Mogball Apr 25, 2025
b4cb1af
fix regular matmul
Mogball Apr 25, 2025
1140e1c
fix
Mogball Apr 26, 2025
4eff58e
fixed
Mogball Apr 26, 2025
38072a6
forgot to handle P
Mogball Apr 26, 2025
8358c0d
fix optzn
Mogball Apr 26, 2025
50c3b55
dep dialect
Mogball Apr 28, 2025
b4a8612
savepoint: OAI benchmarks look good
Mogball Apr 28, 2025
2b90525
rename op
Mogball Apr 28, 2025
a912113
put scales into smem
Mogball Apr 28, 2025
698e94f
put local load in user partition
Mogball Apr 28, 2025
7ffbd86
add another test
Mogball Apr 28, 2025
ee8eda3
add another test
Mogball Apr 28, 2025
93e42ba
Merge remote-tracking branch 'origin/main' into HEAD
Mogball Apr 29, 2025
4fba3f6
refactor pipelineMMA
Mogball Apr 29, 2025
1615916
handle peeled wait
Mogball Apr 29, 2025
d8887ac
Merge branch 'main' into mogball/fmha
Mogball Apr 30, 2025
87 changes: 7 additions & 80 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -718,89 +718,16 @@ void storeDistributedToShared(
RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmStruct.getType()) ||
isa<LLVM::LLVMPointerType>(llvmStruct.getType()))
return {llvmStruct};
ArrayRef<Type> types =
cast<LLVM::LLVMStructType>(llvmStruct.getType()).getBody();
SmallVector<Value> results(types.size());
auto b = TritonLLVMOpBuilder(loc, rewriter);
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = b.extract_val(type, llvmStruct, i);
}
return results;
}

inline Value packLLElements(Location loc,
const LLVMTypeConverter *typeConverter,
ValueRange resultVals, RewriterBase &rewriter,
Type type) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (!structType) {
assert(resultVals.size() == 1);
return *resultVals.begin();
}

auto elementTypes = structType.getBody();
if (elementTypes.size() != resultVals.size()) {
emitError(loc) << " size mismatch when packing elements for LLVM struct"
<< " expected " << elementTypes.size() << " but got "
<< resultVals.size();
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
auto b = TritonLLVMOpBuilder(loc, rewriter);
for (const auto &v : llvm::enumerate(resultVals)) {
if (!v.value()) {
emitError(loc)
<< "cannot insert null values into struct, but tried to insert"
<< v.value();
}
if (v.value().getType() != elementTypes[v.index()]) {
LDBG("type " << type << " structType " << structType);
LDBG("value " << v.value());
emitError(loc) << "invalid element type in packLLElements. Expected "
<< elementTypes[v.index()] << " but got "
<< v.value().getType();
}
llvmStruct = b.insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter);

inline SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
RewriterBase &rewriter) {
assert(bool(llvmVec) && "cannot unpack null value");
if (llvmVec.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmVec.getType()) ||
isa<LLVM::LLVMPointerType>(llvmVec.getType()))
return {llvmVec};
Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter,
ValueRange resultVals, RewriterBase &rewriter, Type type);

auto b = TritonLLVMOpBuilder(loc, rewriter);
SmallVector<Value> results;
for (int i = 0; i < cast<VectorType>(llvmVec.getType()).getNumElements();
i++) {
results.push_back(b.extract_element(llvmVec, b.i32_val(i)));
}
return results;
}
SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
RewriterBase &rewriter);

inline Value packLLVector(Location loc, ValueRange vals,
RewriterBase &rewriter) {
assert(vals.size() > 0);
auto vecType = vec_ty(vals[0].getType(), vals.size());
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value vec = b.undef(vecType);
for (int i = 0; i < vals.size(); i++) {
vec = b.insert_element(vec, vals[i], b.i32_val(i));
}
return vec;
}
Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);

inline bool
isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Partition.h
@@ -45,6 +45,7 @@ class WarpSchedule {
ArrayRef<Operation *> getOps() const { return ops; }

void insert(Operation *op) { ops.push_back(op); }
void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }

private:
void setIndex(int idx) { this->idx = idx; }
@@ -62,6 +63,8 @@ class WarpSchedule {
Partition *addPartition(unsigned stage);
// Give each partition a new index and order. The indices must be unique.
void reorderPartitions(ArrayRef<unsigned> order);
// Update the op to partition mapping.
void updatePartitions();

// Get the partition the op belongs to.
Partition *getPartition(Operation *op);
@@ -115,6 +118,9 @@ class WarpSchedule {
scf::ForOp loop, const Partition *partition,
function_ref<void(OpResult, OpOperand &, unsigned)> callback) const;

// Debug dump the schedule.
LLVM_DUMP_METHOD void dump() const;

private:
// Partitions are numbered [0, N).
SmallVector<std::unique_ptr<Partition>> partitions;
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -174,7 +174,8 @@ def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "
and async MMAs into separate partitions.
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];

let options = [
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
@@ -100,7 +100,8 @@ DenseMap<Operation *, int> deserializeLatencies(Operation *op);
Value createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
unsigned numBuffers);
// Create an allocation and init the mbarriers.
Value createBarrierAlloc(scf::ForOp forOp, int numBarriers);
Value createBarrierAlloc(scf::ForOp forOp, int numBarriers,
int arriveCount = 1);
Reviewer (Collaborator): Is this part of the refactoring, or is it addressing a separate issue?

Author (Collaborator): This is part of the refactor. Load groups can have multiple consumers.

// Create an allocation that can hold distance number of tensor shapes.
Value createAlloc(scf::ForOp forOp, RankedTensorType ty, Location loc,
gpu::SharedEncodingTrait sharedEnc, unsigned distance);
80 changes: 80 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -553,6 +553,86 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
llvm::report_fatal_error("Failed to emit transfer from register to shared");
}

SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmStruct.getType()) ||
isa<LLVM::LLVMPointerType>(llvmStruct.getType()))
return {llvmStruct};
ArrayRef<Type> types =
cast<LLVM::LLVMStructType>(llvmStruct.getType()).getBody();
SmallVector<Value> results(types.size());
auto b = TritonLLVMOpBuilder(loc, rewriter);
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = b.extract_val(type, llvmStruct, i);
}
return results;
}

Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter,
ValueRange resultVals, RewriterBase &rewriter, Type type) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (!structType) {
assert(resultVals.size() == 1);
return *resultVals.begin();
}

auto elementTypes = structType.getBody();
if (elementTypes.size() != resultVals.size()) {
emitError(loc) << " size mismatch when packing elements for LLVM struct"
<< " expected " << elementTypes.size() << " but got "
<< resultVals.size();
llvm::report_fatal_error(
"size mismatch when packing elements for LLVM struct");
}
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
auto b = TritonLLVMOpBuilder(loc, rewriter);
for (auto [i, value] : llvm::enumerate(resultVals)) {
assert(value && "unexpected null value");
if (value.getType() != elementTypes[i]) {
LDBG("type " << type << " structType " << structType);
LDBG("value " << value);
emitError(loc) << "invalid element type in packLLElements. Expected "
<< elementTypes[i] << " but got " << value.getType();
llvm::report_fatal_error(
"element type mismatch when packing elements for LLVM struct");
}
llvmStruct = b.insert_val(structType, llvmStruct, value, i);
}
return llvmStruct;
}

SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
RewriterBase &rewriter) {
assert(bool(llvmVec) && "cannot unpack null value");
if (llvmVec.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmVec.getType()) ||
isa<LLVM::LLVMPointerType>(llvmVec.getType()))
return {llvmVec};

auto b = TritonLLVMOpBuilder(loc, rewriter);
SmallVector<Value> results;
for (int i = 0; i < cast<VectorType>(llvmVec.getType()).getNumElements();
i++) {
results.push_back(b.extract_element(llvmVec, b.i32_val(i)));
}
return results;
}

Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter) {
assert(vals.size() > 0);
auto vecType = vec_ty(vals[0].getType(), vals.size());
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value vec = b.undef(vecType);
for (int i = 0; i < vals.size(); i++) {
vec = b.insert_element(vec, vals[i], b.i32_val(i));
}
return vec;
}

SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
RankedTensorType type) {
MLIRContext *ctx = layout.getContext();
@@ -288,6 +288,7 @@ bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {
std::pair<OpResult, int64_t>
mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) {
int64_t distance = 0;
DenseSet<Value> seen;
while (auto arg = dyn_cast<BlockArgument>(value)) {
// Ignore implicit captures.
if (arg.getOwner() != forOp.getBody())
@@ -297,6 +298,8 @@ mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) {
return {nullptr, 0};
++distance;
value = forOp.getYieldedValues()[arg.getArgNumber() - 1];
if (!seen.insert(value).second)
return {nullptr, 0};
}
Reviewer (Collaborator): This also doesn't feel like refactoring :]

Author (Collaborator): Some of the refactoring exposed a bug :P

return {cast<OpResult>(value), distance};
}
@@ -358,14 +361,15 @@ Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
}

// Create an allocation and init the mbarriers.
Value mlir::triton::createBarrierAlloc(scf::ForOp forOp, int numBarriers) {
Value mlir::triton::createBarrierAlloc(scf::ForOp forOp, int numBarriers,
int arriveCount) {
ImplicitLocOpBuilder rewriter(forOp.getLoc(), forOp);

Value barrierAlloc =
createScalarAlloc(rewriter, rewriter.getI64Type(), numBarriers);
for (unsigned i = 0; i < numBarriers; i++) {
Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i);
rewriter.create<ttng::InitBarrierOp>(barrierView, 1);
rewriter.create<ttng::InitBarrierOp>(barrierView, arriveCount);
}
// Invalidate and deallocate the barriers.
rewriter.setInsertionPointAfter(forOp);
18 changes: 10 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
@@ -92,17 +92,19 @@ tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
SmallVector<SmallVector<std::tuple<Operation *, int, Cluster>>, 8>
orderClusters(clusters.size());
for (auto &op : forOp.getBody()->without_terminator()) {
if (opToStageAndCluster.count(&op) == 0) {
auto it = opToStageAndCluster.find(&op);
if (it == opToStageAndCluster.end()) {
continue;
}
assert(opToStageAndCluster[&op].first < numStages &&
"Op with invalid stage!");
int clusterId = *opToStageAndCluster[&op].second;
assert(clusterId == std::distance(clusters.begin(),
opToStageAndCluster[&op].second) &&
auto [stage, cluster] = it->second;
if (cluster == Cluster{}) {
continue;
}
assert(stage < numStages && "Op with invalid stage!");
int clusterId = *cluster;
assert(clusterId == std::distance(clusters.begin(), cluster) &&
"Cluster ID mismatch!");
orderClusters[clusterId].push_back(make_tuple(
&op, opToStageAndCluster[&op].first, opToStageAndCluster[&op].second));
orderClusters[clusterId].push_back(make_tuple(&op, stage, cluster));
}
SmallVector<std::tuple<Operation *, int, Cluster>> opsInOrder;
for (int i = 0; i < orderClusters.size(); i++) {