[Blackwell] Refactor/slightly generalize warp specialization #6597
Conversation
…ry for scales This ensures scales produced by TMA are eligible for transfer to tensor memory in later lowering. git-pr-chain: csullivan/support_desc_load_tmem_copy
… for tl.dot_scaled This enables automatic warp specialization for block scaled workloads. git-pr-chain: csullivan/support_block_scales_in_warp_spec
…n_warp_spec' into mogball/fmha
Looks good!
passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
# added: promote the dot LHS operand into tensor memory (tmem)
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
I needed this same change for getting FA to run with WarpSpec on Blackwell.
@@ -100,7 +100,8 @@ DenseMap<Operation *, int> deserializeLatencies(Operation *op);
 Value createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
                         unsigned numBuffers);
 // Create an allocation and init the mbarriers.
-Value createBarrierAlloc(scf::ForOp forOp, int numBarriers);
+Value createBarrierAlloc(scf::ForOp forOp, int numBarriers,
+                         int arriveCount = 1);
Is this part of the refactoring, or is it addressing a separate issue?
This is part of the refactor. Load groups can have multiple consumers, so a group's barriers may need an arrive count greater than one.
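A minimal sketch of how the new parameter would be used, assuming a hypothetical per-group consumer count:

// Hypothetical sketch: when a load group's buffers feed more than one
// consumer partition, the barriers are initialized so the producer only
// reuses a buffer after every consumer has arrived on it.
int numConsumers = 2; // hypothetical: e.g. two partitions read this group
Value barriers =
    createBarrierAlloc(forOp, numBuffers, /*arriveCount=*/numConsumers);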
@@ -297,6 +298,8 @@ mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) {
       return {nullptr, 0};
     ++distance;
     value = forOp.getYieldedValues()[arg.getArgNumber() - 1];
+    if (!seen.insert(value).second)
+      return {nullptr, 0};
   }
This also doesn't feel like refactoring :]
Some of the refactoring exposed a bug :P
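To spell out the failure mode (my reading of the fix above):

// Sketch of the cycle the `seen` set guards against: two iter_args that
// yield each other never resolve to a definition outside the loop.
//
//   scf.for ... iter_args(%a = %x, %b = %y) {
//     ...
//     scf.yield %b, %a   // %a <- %b and %b <- %a every iteration
//   }
//
// Walking %a's yielded value gives %b, walking %b gives %a, and so on;
// seen.insert(value) returns false on the revisit, and the walk now bails
// out with {nullptr, 0} instead of looping forever.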
Nice!
This is basically a rewrite of LoadMMASpecialization that cleanly separates the warp specialization of loads and MMAs into discrete steps. Loads and MMAs can now be specialized independently of each other, and an arbitrary number of load groups and MMAs is supported. For now, all loads and MMAs are still placed in the same partitions.
In addition, this PR separates the actual partition assignment from the multibuffering and loop-lowering step, as in software pipelining (SWP). This should make it easier to tweak partitioning strategies.
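A rough sketch of the staged flow this describes (function and type names here are illustrative, not the actual pass entry points):

// Hypothetical outline of the two-step structure described above.
void runWarpSpecialization(scf::ForOp forOp) {
  // Step 1: decide which partition each load group and MMA belongs to.
  // No IR is rewritten yet, so partitioning strategies can be swapped
  // out independently of the lowering below.
  PartitionSchedule schedule = assignPartitions(forOp); // hypothetical
  // Step 2: multibuffer and lower the loop according to the schedule,
  // mirroring how SWP separates scheduling from rewriting.
  lowerLoadsAndMMAs(forOp, schedule); // hypothetical
}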