Performance update on the backward split kernel #127

Open · wants to merge 8 commits into main_perf
Conversation

@jtang10 jtang10 commented Feb 12, 2025

[Benchmark results image]
This PR improves the backward kernel performance, bringing it closer to that of tutorial/06-fused-attention.py. As shown above, on the benchmark in flash_attn/flash_attn_triton_amd we reach on average 90% of the tutorial's performance, a 60% improvement over the previous PR #122.

The improvement comes from the following changes:

  1. Top-of-tree (ToT) Triton compiler, which greatly alleviates the register spilling problem in the dkdv kernel.
  2. Merged dkdv and dq kernels. Nothing changes algorithm-wise; the two kernels are simply merged to share common variables and reduce launch latency, as the tutorial does.
  3. Turned on use_exp2 by default (see the sketch after this list).
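
Item 3 refers to the standard exp2 rewrite also used by the Triton fused-attention tutorial: since exp(x) == exp2(x * log2(e)), the softmax exponentials in the inner loop can be issued as the cheaper hardware exp2 instruction. The toy kernel below only illustrates that substitution; it is not code from this PR, and the kernel name and flag are made up for the example:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def exp_kernel(x_ptr, out_ptr, n, USE_EXP2: tl.constexpr, BLOCK: tl.constexpr):
    # Each program handles one BLOCK-sized chunk of the input.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if USE_EXP2:
        # exp(x) == exp2(x * log2(e)); exp2 lowers to a fast hardware
        # instruction instead of a slower libdevice exp call.
        y = tl.math.exp2(x * 1.4426950408889634)
    else:
        y = tl.exp(x)
    tl.store(out_ptr + offs, y, mask=mask)


x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
exp_kernel[grid](x, out, x.numel(), USE_EXP2=True, BLOCK=256)
assert torch.allclose(out, torch.exp(x), rtol=1e-4, atol=1e-4)
```

In the attention kernels the log2(e) factor is typically folded into the existing softmax scale, so the rewrite does not add a per-element multiply.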
