[Blackwell] Refactor/slightly generalize warp specialization #6597

Merged
merged 61 commits into main from mogball/fmha on May 1, 2025

Conversation

@Mogball Mogball (Collaborator) commented Apr 24, 2025

This is basically a rewrite of LoadMMASpecialization that cleanly separates the warp specialization of loads and MMAs into discrete steps. This allows loads and MMAs to be warp-specialized independently of each other and supports an arbitrary number of load groups and MMAs. All loads and MMAs are still placed in the same partitions.

In addition, this PR separates the actual partition assignment from the multibuffering and loop-lowering step, like SWP. This should make it easier to tweak partitioning strategies.
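
As a rough conceptual sketch of that split (hypothetical names and a toy op representation, not the actual pass code), one step computes an op-to-partition assignment and a separate step consumes it when lowering the loop:

#include <cstdio>
#include <map>
#include <string>
#include <vector>

enum class Partition { Load, MMA, Default };

// Step 1: decide which partition each op belongs to. Tweaking the partitioning
// strategy only means changing this function.
std::map<std::string, Partition>
assignPartitions(const std::vector<std::string> &ops) {
  std::map<std::string, Partition> assignment;
  for (const std::string &op : ops) {
    if (op.rfind("load", 0) == 0)
      assignment[op] = Partition::Load;
    else if (op.rfind("mma", 0) == 0)
      assignment[op] = Partition::MMA;
    else
      assignment[op] = Partition::Default;
  }
  return assignment;
}

// Step 2: consume a finished assignment when multibuffering and lowering the
// loop; this step never decides placement on its own.
void lowerLoop(const std::map<std::string, Partition> &assignment) {
  for (const auto &entry : assignment)
    std::printf("%s -> partition %d\n", entry.first.c_str(),
                static_cast<int>(entry.second));
}

int main() {
  std::vector<std::string> ops = {"load_a", "load_scales", "mma_0", "epilogue"};
  lowerLoop(assignPartitions(ops));
  return 0;
}

Swapping in a different partitioning strategy would then only mean replacing assignPartitions; the lowering step stays untouched.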

Mogball and others added 30 commits April 16, 2025 20:29
…ry for scales

This ensures scales produced by TMA are eligible for transfer to tensor
memory in later lowering.

git-pr-chain: csullivan/support_desc_load_tmem_copy
… for tl.dot_scaled

This enables automatic warp specialization for block scaled workloads.

git-pr-chain: csullivan/support_block_scales_in_warp_spec
Base automatically changed from mogball/tmem_toks to main April 29, 2025 01:05
@Mogball Mogball changed the title [WIP][DNR] Refactor warp specialization [Blackwell] Refactor warp specialization Apr 29, 2025
@Mogball Mogball marked this pull request as ready for review April 29, 2025 02:19
@Mogball Mogball requested a review from ptillet as a code owner April 29, 2025 02:19
@Mogball Mogball requested a review from ThomasRaoux April 29, 2025 18:34
@manman-ren manman-ren (Collaborator) left a comment

Looks good!

passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
Collaborator

I needed this same change for getting FA to run with WarpSpec on Blackwell.

@@ -100,7 +100,8 @@ DenseMap<Operation *, int> deserializeLatencies(Operation *op);
 Value createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
                         unsigned numBuffers);
 // Create an allocation and init the mbarriers.
-Value createBarrierAlloc(scf::ForOp forOp, int numBarriers);
+Value createBarrierAlloc(scf::ForOp forOp, int numBarriers,
+                         int arriveCount = 1);
Collaborator

Is this part of refactoring? Or is it addressing a separate issue?

Collaborator Author

This is part of the refactor. Load groups can have multiple consumers
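
For context on why the arrive count matters, here is a minimal stand-alone model of mbarrier arrive-count semantics (an illustration only, not Triton's actual barrier lowering): when a load group's buffer has N consumers, the guarding barrier should only complete its phase after all N have arrived, so it is allocated with arriveCount = N.

#include <cassert>
#include <cstdio>

struct ModelBarrier {
  int arriveCount;   // expected arrivals per phase
  int pending;       // arrivals still outstanding in the current phase
  int phase = 0;     // flips once all expected arrivals have happened

  explicit ModelBarrier(int count) : arriveCount(count), pending(count) {}

  // Each consumer calls arrive() once per iteration.
  void arrive() {
    if (--pending == 0) {
      phase ^= 1;             // complete the phase
      pending = arriveCount;  // re-arm for the next phase
    }
  }
};

int main() {
  // A load group with two consumers, e.g. two MMAs reading the same buffer.
  ModelBarrier bar(/*arriveCount=*/2);
  bar.arrive();            // first consumer is done with the buffer
  assert(bar.phase == 0);  // the producer may not reuse the buffer yet
  bar.arrive();            // second consumer is done
  assert(bar.phase == 1);  // now the buffer can be overwritten
  std::printf("phase flipped after %d arrivals\n", bar.arriveCount);
  return 0;
}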

@@ -297,6 +298,8 @@ mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) {
       return {nullptr, 0};
     ++distance;
     value = forOp.getYieldedValues()[arg.getArgNumber() - 1];
+    if (!seen.insert(value).second)
+      return {nullptr, 0};
   }
Collaborator

This also doesn't feel like refactoring :]

Collaborator Author

Some of the refactoring exposed a bug :P
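
A stand-alone sketch of the pattern the added seen check guards against (hypothetical integer indices standing in for the MLIR values): walking loop-carried values back through the yield can revisit the same value when iter_args feed each other in a cycle, so the walk has to bail out instead of looping forever.

#include <cstdio>
#include <unordered_set>
#include <vector>

// yieldedOf[i] is the index of the value yielded into block argument i; an
// index >= yieldedOf.size() stands for a value defined outside the loop.
int definitionDistance(const std::vector<int> &yieldedOf, int value) {
  std::unordered_set<int> seen;
  int distance = 0;
  while (value < static_cast<int>(yieldedOf.size())) {
    ++distance;
    value = yieldedOf[value];
    // Bail out if we have already visited this value: the iter_args form a
    // cycle. (The real code returns {nullptr, 0} here.)
    if (!seen.insert(value).second)
      return -1;
  }
  return distance;
}

int main() {
  // arg0 is fed by arg1 and arg1 is fed back by arg0: a two-element cycle.
  std::vector<int> cyclic = {1, 0};
  std::printf("cyclic chain  -> %d\n", definitionDistance(cyclic, 0));
  // arg0 is fed by arg1, and arg1 is fed by a value defined outside the loop.
  std::vector<int> acyclic = {1, 2};
  std::printf("acyclic chain -> %d\n", definitionDistance(acyclic, 0));
  return 0;
}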

@Mogball Mogball changed the title [Blackwell] Refactor warp specialization [Blackwell] Refactor/slightly generalize warp specialization Apr 30, 2025
@ThomasRaoux ThomasRaoux (Collaborator) left a comment

Nice!

@Mogball Mogball merged commit 0719c00 into main May 1, 2025
8 checks passed
@Mogball Mogball deleted the mogball/fmha branch May 1, 2025 18:55