[TritonNVIDIAGPU] Add dependency tokens to TMEM ops #6520


Merged · 25 commits into main from mogball/tmem_toks · Apr 29, 2025

Conversation

@Mogball Mogball commented Apr 17, 2025

The Triton middle-end has perfect dependency and modref information about TMEM (and shared memory), because TMEM is introduced by the middle-end itself when it expands chains of SSA ops. E.g. HoistTMEMAlloc is essentially a form of reg-2-mem for MMA accumulators.

Despite this, the dependency and alias analyses needed by HoistTMEMAlloc, warp specialization, and the pipeliner rely on ad-hoc checks that are not always correct and that are becoming increasingly complex. Instead of building a stronger memory analysis, we can simply stop discarding the information the compiler already has.

This PR adds tokens to all the ops that touch TMEM (except TMEMCopyOp, since it is not used in the middle-end). The tokens act as a form of MemorySSA (a memory-variable lattice encoded in the IR), and the middle-end leverages them to check aliasing, modref, etc. instead of scanning the IR. Consequently, the transformations are more robust and easier to maintain, at the cost of some extra book-keeping.
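
As a hedged illustration of the intended shape, here is a hand-written pseudo-IR sketch of how the tokens thread through TMEM accesses. The op and type spellings are only approximations of the TritonNVIDIAGPU dialect, and the `dep` keyword is invented for readability, so treat this as a sketch rather than actual compiler output:

```mlir
// Pseudo-IR sketch (approximate op names and syntax).
// Every TMEM access consumes the token of the access it must observe and
// produces a new token for later accesses, giving a MemorySSA-like chain.
%buf, %tok0 = ttng.tmem_alloc                        // allocation produces the first token
%tok1 = ttng.tmem_store %acc_init, %buf dep %tok0    // initialize the accumulator
%tok2 = ttng.tc_gen5_mma %a, %b, %buf dep %tok1      // MMA writes the same buffer
%res, %tok3 = ttng.tmem_load %buf dep %tok2          // RAW edge on the MMA
%tok4 = ttng.tmem_store %acc_next, %buf dep %tok3    // WAR edge on the load
```

A pass can then answer modref queries by walking these token edges instead of scanning the IR for operations that might touch the same allocation.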

This will greatly simplify the dependence analysis needed by more complex warp specialization, and it will help with composing warp specialization with the pipeliner (cc @htyu @manman-ren).

There would be a pretty big performance cliff if this PR were wrong (i.e. if it failed to pipeline or warp specialize), so I sanity-checked that it did not break pipelining.

### Performance numbers after

```
├─ 703.378 976.992 matmul_kernel [M=8192, N=8192, K=512]
├─ 936.461 733.821 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 938.393 732.310 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.351 802.468 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 785.072 875.327 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1024.165 670.981 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1125.056 610.810 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 800.940 857.986 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     183.032906    176.540661
1   2048.0     384.363999    417.633483
2   4096.0     471.816004    511.814693
3   8192.0     519.752669    566.761880
4  16384.0     545.707761    595.042579
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     364.631059    364.685641
1   2048.0     492.108137    536.102664
2   4096.0     532.795804    580.166599
3   8192.0     550.670842    599.591255
4  16384.0     559.480705    608.551411
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     144.731066    152.721176
1   2048.0     234.101200    234.195236
2   4096.0     293.602665    293.519568
3   8192.0     331.644550    331.388321
4  16384.0     355.252999    354.861517
```

```
Problem Shape = 8192x8192x512
└─ 974.209 705.458 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```

### Performance numbers before

```
├─ 708.163 970.391 matmul_kernel [M=8192, N=8192, K=512]
├─ 935.792 734.346 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 922.666 744.793 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.643 802.195 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 792.424 867.206 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1020.997 673.063 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1134.083 605.948 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 799.650 859.369 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     181.507652    183.077756
1   2048.0     384.836411    416.908797
2   4096.0     471.260742    512.709282
3   8192.0     519.896730    566.172554
4  16384.0     545.181917    595.246382
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     368.266771    373.516950
1   2048.0     492.137719    535.968650
2   4096.0     533.092876    580.134559
3   8192.0     550.571575    599.455669
4  16384.0     559.555689    608.442981
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     151.081525    155.745186
1   2048.0     234.359406    234.108984
2   4096.0     293.584945    293.689437
3   8192.0     331.633380    331.669234
4  16384.0     355.077635    354.963313
```

```
Problem Shape = 8192x8192x512
└─ 972.794 706.484 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```

@Mogball Mogball changed the title from "[WIP][DNR] TMEM dependency tokens" to "[TritonNVIDIAGPU] Add dependency tokens to TMEM ops" on Apr 17, 2025
@Mogball Mogball marked this pull request as ready for review April 17, 2025 17:51
@Mogball Mogball requested a review from ptillet as a code owner April 17, 2025 17:51

@ThomasRaoux ThomasRaoux left a comment

It's probably worth describing a bit the expected semantics of those tokens.
Is it that when the tokens are there, there is an explicit token dependency between any TMEM operations that alias? And when the tokens are not there, the semantics fall back to just normal aliasing?

```tablegen
  );
  let results = (outs
    TT_Tensor:$result,
    Optional<TTG_AsyncToken>:$token
```

Collaborator:

Why does load return a token even though it doesn't modify TMEM?

Collaborator Author (Mogball):

The idea is that if the ops have tokens, they could be marked Pure and the IR would still be semantically correct. This means that WAR dependencies have to be represented by passing the result token from the load as an operand to the next store.

For example, this token is how the load-store forwarding pattern in HoistTMEMAlloc is implemented: it pattern-matches tmem_store(tmem_load), where the compiler knows the store modifies the memory after the load because of the token.
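
As a rough sketch of that shape (same approximate pseudo-IR and invented `dep` syntax as above, not the actual assembly format), the pattern can recognize the pair because the store's token operand is exactly the load's result token, which proves nothing else modified the buffer in between:

```mlir
// Before: a redundant round-trip through TMEM, fully ordered by tokens.
%v, %tokL = ttng.tmem_load %buf dep %tok0       // read the accumulator
%tokS     = ttng.tmem_store %v, %buf dep %tokL  // store the same value back

// After load-store forwarding: the store is a no-op, so it can be erased and
// later uses of %tokS rewired to %tokL, keeping the token chain intact.
```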

@Mogball Mogball commented Apr 20, 2025

> It's probably worth describing a bit the expected semantics of those tokens. Is it that when the tokens are there, there is an explicit token dependency between any TMEM operations that alias? And when the tokens are not there, the semantics fall back to just normal aliasing?

Yes, exactly. The tokens encode known modref information in the IR, and when they are present, they can be used to refine regular alias analysis. I'll add a little blurb to each op describing this.

@pawelszczerbuk pawelszczerbuk commented Apr 21, 2025

> It's probably worth describing a bit the expected semantics of those tokens. Is it that when the tokens are there, there is an explicit token dependency between any TMEM operations that alias? And when the tokens are not there, the semantics fall back to just normal aliasing?
>
> Yes, exactly. The tokens encode known modref information in the IR, and when they are present, they can be used to refine regular alias analysis. I'll add a little blurb to each op describing this.

From looking at the code, I am not sure that "falls back to just normal aliasing" is correct. In HoistTMEMAlloc we now rely exclusively on the tokens to check for aliasing, right?
If that is true, I wonder whether having the tokens be optional is correct, since we rely on them for the correctness of the transformations.

@Mogball Mogball commented Apr 21, 2025

> It's probably worth describing a bit the expected semantics of those tokens. Is it that when the tokens are there, there is an explicit token dependency between any TMEM operations that alias? And when the tokens are not there, the semantics fall back to just normal aliasing?
>
> Yes, exactly. The tokens encode known modref information in the IR, and when they are present, they can be used to refine regular alias analysis. I'll add a little blurb to each op describing this.
>
> From looking at the code, I am not sure that "falls back to just normal aliasing" is correct. In HoistTMEMAlloc we now rely exclusively on the tokens to check for aliasing, right? If that is true, I wonder whether having the tokens be optional is correct, since we rely on them for the correctness of the transformations.

HoistTMEMAlloc does expect the tokens to be present. The tokens have to be discarded after multibuffering, however, since the encoded modref information becomes too strong: the tokens would indicate modref between operations that touch different subslices, and this would break pipelining and warp specialization.

In practice, tokens will always be present during HoistTMEMAlloc, but I will modify it to fail if they are not.
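
As a hedged sketch of why the tokens become too strong after multibuffering (again approximate pseudo-IR; the subview usage and the `dep` syntax are assumptions for illustration only):

```mlir
// After multibuffering, each iteration writes a different slice of the
// enlarged allocation, e.g. via some subview/subslice op.
%bufs, %tok0 = ttng.tmem_alloc                  // now holds N buffers
%slice0 = ttg.memdesc_subview %bufs[%c0]        // hypothetical subview usage
%slice1 = ttg.memdesc_subview %bufs[%c1]

// Keeping the tokens would chain the two stores even though the slices never
// alias, serializing work that the pipeliner and warp specializer need to
// overlap; hence the tokens are dropped at this point.
%tok1 = ttng.tmem_store %acc_a, %slice0 dep %tok0
%tok2 = ttng.tmem_store %acc_b, %slice1 dep %tok1   // spurious dependency
```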

@Mogball Mogball commented Apr 21, 2025

Chatted with @pawelszczerbuk about this offline. HoistTMEMAlloc will explicitly check for the tokens on the ops using a CRTP subclass.

@ThomasRaoux ThomasRaoux left a comment

LGTM

@Mogball Mogball merged commit f0530d2 into main Apr 29, 2025
8 checks passed
@Mogball Mogball deleted the mogball/tmem_toks branch April 29, 2025 01:05
FindHao pushed a commit to FindHao/triton that referenced this pull request Apr 30, 2025