Performant backward Triton implementation with separated dkdv and dq kernels #122

Merged
34 commits merged from jingtang/fa_bwd_split into main_perf on Feb 4, 2025

Conversation

@jtang10 commented Jan 31, 2025

This PR introduces a performant version of the backward kernel implementation in Triton. It follows the same implementation strategy as the example in upstream Triton but with added functionality.
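To make the split strategy concrete, here is a minimal single-head PyTorch reference sketch (not the Triton kernels themselves; no causal mask, no dropout, and `attention_bwd_split_ref`/`BLOCK` are illustrative names). The first loop mirrors the dkdv kernel's grid: each K/V block owns its dk/dv and sweeps all Q rows. The second loop mirrors the dq kernel's grid: each Q block owns its dq and sweeps all K/V columns, so every output block is written by exactly one program.

```python
import torch

def attention_bwd_split_ref(q, k, v, do, sm_scale, BLOCK=64):
    # Recompute the forward quantities the backward pass needs.
    s = (q @ k.transpose(-1, -2)) * sm_scale      # (seqlen_q, seqlen_k)
    p = torch.softmax(s, dim=-1)
    o = p @ v
    delta = (do * o).sum(-1, keepdim=True)        # row-wise D_i = sum_d dO*O

    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

    # Pass 1 (dkdv kernel): each K/V block accumulates its dk/dv over all Q rows.
    for kv_start in range(0, k.shape[0], BLOCK):
        kv = slice(kv_start, kv_start + BLOCK)
        p_blk = p[:, kv]                          # (seqlen_q, blk)
        dp_blk = do @ v[kv].transpose(-1, -2)     # (seqlen_q, blk)
        ds_blk = p_blk * (dp_blk - delta) * sm_scale
        dv[kv] = p_blk.transpose(-1, -2) @ do
        dk[kv] = ds_blk.transpose(-1, -2) @ q

    # Pass 2 (dq kernel): each Q block accumulates its dq over all K/V columns.
    for q_start in range(0, q.shape[0], BLOCK):
        qs = slice(q_start, q_start + BLOCK)
        dp_blk = do[qs] @ v.transpose(-1, -2)     # (blk, seqlen_k)
        ds_blk = p[qs] * (dp_blk - delta[qs]) * sm_scale
        dq[qs] = ds_blk @ k

    return dq, dk, dv

# Quick numerical check against autograd (float64, single head).
q, k, v = (torch.randn(128, 64, dtype=torch.float64, requires_grad=True) for _ in range(3))
do = torch.randn(128, 64, dtype=torch.float64)
o = torch.softmax((q @ k.T) * 0.125, dim=-1) @ v
o.backward(do)
dq, dk, dv = attention_bwd_split_ref(q.detach(), k.detach(), v.detach(), do, 0.125)
assert torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad)
```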

As expected, it outperforms the existing implementation, but still reaches only about half the performance of the upstream example mentioned above. Here's a quick summary of the performance at the time of submitting this PR.
[performance summary image]
Note that tutorials/06-fused-attention.py assumes Q and KV have the same number of heads and the same sequence length, so it is only tested on those cases.

The performance investigation will follow in a later PR.

@jtang10 (Author) commented Jan 31, 2025

The switch between the new kernel in bwd_prefill_split.py and the existing bwd_prefill.py is controlled by an environment variable:

USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes')

Though we discussed that toggling this via a function argument would be preferred, the Triton kernel wrapper is passed into the interfaces, which makes it hard to control through an argument imo. So I've left it as an environment variable, but turned the split kernel on by default.
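For reference, opting back into the original single-kernel path would look roughly like this (a sketch; since the module-level assignment above suggests the flag is read at import time, it has to be set before importing flash_attn):

```python
import os

# Hypothetical usage: fall back to the single backward kernel (bwd_prefill.py).
# Leaving the variable unset (default '0') keeps the split kernel enabled.
os.environ["USE_SINGLE_BWD_KERNEL"] = "1"
```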

Resolved review threads on:
flash_attn/flash_attn_interface.py
flash_attn/flash_attn_triton_amd/bwd_prefill.py
flash_attn/flash_attn_triton_amd/bwd_ref.py
flash_attn/flash_attn_triton_amd/interface_fa.py
flash_attn/flash_attn_triton_amd/test.py
@jtang10 force-pushed the jingtang/fa_bwd_split branch from ea973eb to 383d8c7 on February 3, 2025 04:29
@jtang10 (Author) commented Feb 3, 2025

Most of the changes that removed or modified the existing debug messages were contained in a single commit, which has therefore been dropped.

@jtang10 force-pushed the jingtang/fa_bwd_split branch from 814471e to 7d83cd6 on February 3, 2025 23:10
@jtang10 (Author) commented Feb 3, 2025

The problem introduced in 58941ed has been manually reverted in 8436dc7 and fixed in 814471e.

The solution is simple: do not assume that any two tensors share the same strides, and instead pass each tensor's own strides. For example, q and dq do not necessarily have the same strides, because q is potentially padded on its own (padding seems to pack the tensor and make it contiguous first) and loaded in the forward() path, whereas dq is indexed out of dqkv in the backward() path and is discontiguous along the size-3 dimension of (b, s, 3, h, d). As a result, assuming q and dq share strides leads to wrong results.
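A small runnable illustration of that pitfall (the shapes are arbitrary; the point is that q and a dq view sliced from a packed dqkv buffer report different strides despite identical shapes):

```python
import torch

b, s, h, d = 2, 128, 8, 64
q = torch.randn(b, s, h, d)        # contiguous: strides (65536, 512, 64, 1)
dqkv = torch.zeros(b, s, 3, h, d)  # packed gradient buffer
dq = dqkv[:, :, 0]                 # same shape as q, but a non-contiguous view

assert q.shape == dq.shape
print(q.stride())   # (65536, 512, 64, 1)
print(dq.stride())  # (196608, 1536, 64, 1) -- differs on the batch/seq dims

# Hence the wrapper passes q.stride() and dq.stride() to the kernel separately
# instead of assuming q's strides also describe dq.
```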

Another change is to NOT call .contiguous() on q, k, v at the beginning of the Triton wrapper. After the change above this is unnecessary, and it only creates additional data movement.

TODO:
I have kept the hack from your version removed, but didn't patch this update into it. I'm going to focus on the perf part, so I'll leave it to you @micmelesse to get it updated.

@jtang10 force-pushed the jingtang/fa_bwd_split branch from 7d83cd6 to ada4bb8 on February 3, 2025 23:22
@jtang10 force-pushed the jingtang/fa_bwd_split branch 2 times, most recently from ddd07df to bacc596 on February 4, 2025 16:39
@micmelesse (Collaborator) left a comment

This is excellent

@micmelesse merged commit 929f0e8 into main_perf on Feb 4, 2025
1 check passed