Performant backward Triton implementation with separated dkdv and dq kernels #122
Conversation
The switch between the new kernel in bwd_prefill_split.py and the existing bwd_prefill.py is controlled by an environment variable here.
Although we discussed that a function argument would be the preferred toggle, the Triton kernel wrapper is passed directly into the interfaces, which makes it hard to control through an argument, in my opinion. I have left it as an environment variable for now, with the split kernel enabled by default.
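For illustration, a minimal sketch of this kind of environment-variable toggle, assuming a hypothetical variable name and placeholder wrapper callables (the PR's actual names may differ):

```python
import os

# Hypothetical variable name; the split kernel is the default, matching the PR.
USE_SPLIT_BWD = os.environ.get("USE_SPLIT_BWD_KERNEL", "1") == "1"

def select_backward_kernel(bwd_prefill_split, bwd_prefill):
    # The two callables stand in for the wrappers in bwd_prefill_split.py
    # and bwd_prefill.py; the choice is made once, from the environment.
    return bwd_prefill_split if USE_SPLIT_BWD else bwd_prefill
```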
Most of the changes related to removing or modifying the existing debug messages are contained in a single commit, which has therefore been dropped.
The problem introduced in 58941ed was manually reverted in 8436dc7 and fixed in 814471e. The fix is simple: do not assume that any two tensors share the same strides; take the strides from each tensor itself. For example,
Another change is to NOT
TODO:
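A minimal sketch of the stride fix described above, with a hypothetical launch wrapper and argument order (the real kernel signature lives in bwd_prefill_split.py): each tensor passes its own strides instead of borrowing another tensor's.

```python
import torch

def _strides(t: torch.Tensor):
    # Every tensor reports its own memory layout; never reuse another tensor's strides.
    return t.stride(0), t.stride(1), t.stride(2), t.stride(3)

def launch_bwd_kernel(kernel, grid, q, k, v, do, dq, dk, dv, softmax_lse, delta, sm_scale):
    # Hypothetical launch wrapper, for illustration only.
    kernel[grid](
        q, k, v, do, dq, dk, dv, softmax_lse, delta, sm_scale,
        *_strides(q),   # q:  (batch, heads_q, seqlen_q, head_dim)
        *_strides(k),   # k/v may have fewer heads or a different seqlen than q,
        *_strides(v),   # so their strides must come from k and v themselves
        *_strides(do),
        *_strides(dq),
        *_strides(dk),
        *_strides(dv),
    )
```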
This is excellent
This PR introduces a performant version of the backward kernel implementation in Triton. It follows the same implementation strategy as the example in upstream Triton, but with added functionality.
As expected, it surpasses the existing implementation, but it reaches only about half the performance of the upstream example mentioned above. Here's a quick summary of the performance at the time of submitting this PR.

Note that tutorials/06-fused-attention.py assumes Q and KV have the same number of heads and the same sequence length, so it is only tested on those configurations.
The performance investigation will follow in a later PR.
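As a rough illustration of the split strategy in the title (separate dkdv and dq kernels), here is a hedged sketch of a driver. The kernel arguments, function name, grid shapes, and tensor layout (batch, heads, seqlen, head_dim) are all assumptions for illustration; the actual kernels live in bwd_prefill_split.py.

```python
import triton

def flash_attn_bwd_split(bwd_dkdv_kernel, bwd_dq_kernel,
                         q, k, v, do, dq, dk, dv, softmax_lse, delta, sm_scale,
                         BLOCK_M=64, BLOCK_N=64):
    # Hypothetical driver: the two kernel arguments stand in for the PR's Triton kernels.
    batch, heads_q, seqlen_q, head_dim = q.shape
    _, heads_kv, seqlen_k, _ = k.shape

    # Kernel 1: each program owns a block of K/V rows, loops over Q blocks,
    # and accumulates dK and dV for that block.
    grid_dkdv = (triton.cdiv(seqlen_k, BLOCK_N), batch * heads_kv)
    bwd_dkdv_kernel[grid_dkdv](q, k, v, do, dk, dv, softmax_lse, delta, sm_scale,
                               BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)

    # Kernel 2: each program owns a block of Q rows, loops over K/V blocks,
    # and accumulates dQ for that block.
    grid_dq = (triton.cdiv(seqlen_q, BLOCK_M), batch * heads_q)
    bwd_dq_kernel[grid_dq](q, k, v, do, dq, softmax_lse, delta, sm_scale,
                           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
```

Splitting the two accumulations avoids atomics or cross-block reductions: the dK/dV kernel parallelizes over key/value blocks while the dQ kernel parallelizes over query blocks, each writing only the gradients it owns.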