# [TritonNVIDIAGPU] Add dependency tokens to TMEM ops #6520
## Conversation
It's probably worth describing the expected semantics of these tokens a bit.
Is it that when the tokens are present, there is an explicit token dependency between any TMEM operations that alias? And when the tokens are not present, the semantics fall back to just normal aliasing?
  );
  let results = (outs
    TT_Tensor:$result,
    Optional<TTG_AsyncToken>:$token
Why does the load return a token even though it doesn't modify TMEM?
The idea is that if the ops have tokens, they could be marked `Pure` and the IR would still be semantically correct. This means that WAR dependencies have to be represented by passing the result token from the load as an operand to the next store.
For example, this token is how the load-store forwarding pattern in `HoistTMEMAlloc` is implemented: it pattern-matches `tmem_store(tmem_load)`, where the compiler knows the store modifies the memory after the load because of the token.
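As a rough sketch of what that token chain looks like (the op names follow the TMEM ops in this PR, but the textual syntax and types below are only illustrative, not the exact printed IR):

```mlir
// Illustrative sketch only: operand order, syntax, and types are approximate.
%buf, %tok0 = ttng.tmem_alloc ...                    // alloc defines the initial token
%val, %tok1 = ttng.tmem_load %buf, %tok0 ...         // load reads TMEM and yields a new !ttg.async.token
%tok2       = ttng.tmem_store %new, %buf, %tok1 ...  // store consumes %tok1, so the WAR edge is explicit SSA
```

With the ordering carried by the tokens, the store's dependence on the earlier load is a plain use-def edge, which is what would allow the ops to be marked `Pure` and lets `HoistTMEMAlloc` match `tmem_store(tmem_load)` without scanning the IR for intervening accesses.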
Yes, exactly. The tokens encode known modref information in the IR and, when present, can be used to refine regular alias analysis. I'll add a little blurb to each op describing this.
From looking at the code, I am not sure that "falls back to just normal aliasing" is correct. In `HoistTMEMAlloc` we now rely exclusively on the tokens to check for aliasing, right?
`HoistTMEMAlloc` does expect the tokens to be present. The tokens have to be discarded after multibuffering, however, since the encoded modref information becomes too strong: the tokens would indicate modref between operations that operate on different subslices, and that would break pipelining and warp specialization. In practice, tokens will always be present during `HoistTMEMAlloc`, but I will modify it to fail if they are not.
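As a rough illustration of why the tokens become too strong after multibuffering (again, the syntax and the subslice indexing below are only a sketch, not the real ops' printed form):

```mlir
// Illustrative sketch only: a single token chain over-approximates modref after multibuffering.
%tok1 = ttng.tmem_store %v, %buf[0], %tok0 ...    // writes buffer slice 0
%val, %tok2 = ttng.tmem_load %buf[1], %tok1 ...   // reads slice 1: no real aliasing, yet the token
                                                  // still orders it after the store, which would
                                                  // block pipelining and warp specialization
```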
Chatted with @pawelszczerbuk about this offline. `HoistTMEMAlloc` will explicitly check for the tokens on the ops using a CRTP subclass.
LGTM
The Triton middle-end has perfect dependency+modref information about TMEM (and shared memory) because it is introduced by the middle-end by expanding chains of SSA ops. E.g. `HoistTMEMAlloc` is essentially a form of reg-2-mem for MMA accumulators.

Despite this, the dependency and alias analysis needed by `HoistTMEMAlloc`, warp specialization, and the pipeliner rely on ad-hoc checks that are not always correct and that are becoming increasingly complex. Instead of building stronger memory analysis, we can simply stop discarding the information the compiler already has.

This PR adds tokens to all the ops that touch TMEM (except `TMEMCopyOp`, since it is not used in the middle-end). The tokens act as a form of MemorySSA (a memory-variable lattice encoded in the IR), and the middle-end leverages them to check aliasing, modref, etc. instead of scanning the IR. Consequently, the transformations are more robust and easier to maintain, at the cost of some extra book-keeping. This will greatly simplify the dependence analysis needed by more complex warp specialization, and help with composing warp specialization with the pipeliner (cc @htyu @manman-ren).

There would be a pretty big performance cliff if this PR were wrong (failed to pipeline/warp specialize), so I sanity checked that it did not break pipelining.

### Performance numbers after

```
├─ 703.378 976.992 matmul_kernel [M=8192, N=8192, K=512]
├─ 936.461 733.821 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 938.393 732.310 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.351 802.468 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 785.072 875.327 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1024.165 670.981 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1125.056 610.810 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 800.940 857.986 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     183.032906    176.540661
1   2048.0     384.363999    417.633483
2   4096.0     471.816004    511.814693
3   8192.0     519.752669    566.761880
4  16384.0     545.707761    595.042579
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     364.631059    364.685641
1   2048.0     492.108137    536.102664
2   4096.0     532.795804    580.166599
3   8192.0     550.670842    599.591255
4  16384.0     559.480705    608.551411
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     144.731066    152.721176
1   2048.0     234.101200    234.195236
2   4096.0     293.602665    293.519568
3   8192.0     331.644550    331.388321
4  16384.0     355.252999    354.861517
```

```
Problem Shape = 8192x8192x512
└─ 974.209 705.458 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```

### Performance numbers before

```
├─ 708.163 970.391 matmul_kernel [M=8192, N=8192, K=512]
├─ 935.792 734.346 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 922.666 744.793 matmul_kernel_descriptor_persistent_ws [M=8192, N=8192, K=512]
├─ 856.643 802.195 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 792.424 867.206 matmul_kernel_tma [M=8192, N=8192, K=512]
├─ 1020.997 673.063 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
├─ 1134.083 605.948 matmul_kernel_tma_persistent_ws [M=8192, N=8192, K=512]
├─ 799.650 859.369 matmul_kernel_tma_ws [M=8192, N=8192, K=512]
```

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     181.507652    183.077756
1   2048.0     384.836411    416.908797
2   4096.0     471.260742    512.709282
3   8192.0     519.896730    566.172554
4  16384.0     545.181917    595.246382
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     368.266771    373.516950
1   2048.0     492.137719    535.968650
2   4096.0     533.092876    580.134559
3   8192.0     550.571575    599.455669
4  16384.0     559.555689    608.442981
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0     151.081525    155.745186
1   2048.0     234.359406    234.108984
2   4096.0     293.584945    293.689437
3   8192.0     331.633380    331.669234
4  16384.0     355.077635    354.963313
```

```
Problem Shape = 8192x8192x512
└─ 972.794 706.484 block_scaled_matmul_kernel_nvfp4 [M=8192, N=8192, K=512]
```