Memcpy kernel for flash attention #29


Merged
16 commits merged on Apr 11, 2023
Conversation

@suquark (Contributor) commented Apr 6, 2023

Memcpy kernel for flash attention

num_tokens: 64, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.008 ms
[Throughput] gather_cached_kv: 156.479 GB/s
num_tokens: 128, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.011 ms
[Throughput] gather_cached_kv: 216.171 GB/s
num_tokens: 256, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.032 ms
[Throughput] gather_cached_kv: 152.631 GB/s
num_tokens: 512, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.057 ms
[Throughput] gather_cached_kv: 172.325 GB/s
num_tokens: 1024, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.104 ms
[Throughput] gather_cached_kv: 187.537 GB/s
num_tokens: 2048, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.204 ms
[Throughput] gather_cached_kv: 191.603 GB/s

The performance is pretty good considering that the memory layout is not ideal; for reference, the theoretical peak memory bandwidth of an A100-40GB is about 1.6 TB/s.
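
As a sanity check, the reported throughput is consistent with counting each fp16 key and value element once and dividing by the measured latency. A quick sketch (the GiB-based unit and the exact byte accounting are my assumptions, not taken from the benchmark script):

def gather_throughput_gib_s(num_tokens, num_heads, head_size,
                            latency_ms, bytes_per_elem=2):
    # One key vector and one value vector per token, fp16 (2 bytes/elem).
    bytes_moved = num_tokens * num_heads * head_size * 2 * bytes_per_elem
    return bytes_moved / (latency_ms * 1e-3) / 2**30

# The 2048-token row above:
print(gather_throughput_gib_s(2048, 40, 128, 0.204))  # ~191.5 vs. reported 191.603 (latency is rounded)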

Results for the unoptimized kernel:

num_tokens: 64, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.010 ms
[Throughput] gather_cached_kv: 125.891 GB/s
num_tokens: 128, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.015 ms
[Throughput] gather_cached_kv: 160.678 GB/s
num_tokens: 256, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.032 ms
[Throughput] gather_cached_kv: 150.732 GB/s
num_tokens: 512, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.060 ms
[Throughput] gather_cached_kv: 162.482 GB/s
num_tokens: 1024, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.108 ms
[Throughput] gather_cached_kv: 180.763 GB/s
num_tokens: 2048, num_heads: 40, head_size: 128, block_size: 8, num_blocks: 1024, dtype: torch.float16
[Latency] gather_cached_kv: 0.206 ms
[Throughput] gather_cached_kv: 189.757 GB/s

The optimized kernel performs noticeably better at small token counts (up to ~20% speedup); at larger sizes the two kernels converge.
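
For context on what the kernel does: it gathers the per-token key/value vectors, which are scattered across paged cache blocks, into contiguous tensors that the flash attention kernel can consume. A minimal PyTorch sketch of the semantics (the [num_blocks, num_heads, head_size, block_size] cache layout and the slot_mapping name are illustrative assumptions, not the exact layout used by the CUDA kernel):

import torch

def gather_cached_kv_ref(key_cache, value_cache, slot_mapping):
    # key_cache, value_cache: [num_blocks, num_heads, head_size, block_size] (assumed layout)
    # slot_mapping: [num_tokens] long tensor mapping each token to its physical cache slot
    block_size = key_cache.shape[-1]
    block_idx = slot_mapping // block_size  # cache block holding each token
    block_off = slot_mapping % block_size   # token's offset within its block
    # Advanced indexing gathers one [num_heads, head_size] slice per token,
    # producing contiguous [num_tokens, num_heads, head_size] outputs.
    key = key_cache[block_idx, :, :, block_off]
    value = value_cache[block_idx, :, :, block_off]
    return key, value

The CUDA kernel performs this copy directly; per the commit messages below, the optimized version stages data through shared memory and tunes the thread count, which appears to account for the small-batch speedup.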

@suquark changed the title from "Memcpy for flashattn" to "Memcpy kernel for flashattn" on Apr 6, 2023
@suquark changed the title from "Memcpy kernel for flashattn" to "Memcpy kernel for flash attention" on Apr 6, 2023
@suquark (Contributor, Author) commented Apr 6, 2023

Implementation is done; it still needs testing (I will do that on Thursday).

The memory-saving strategy is orthogonal to this kernel, so I won't include it in this PR.

@suquark requested a review from WoosukKwon on April 6, 2023 at 08:57
@suquark force-pushed the memcpy4flashattn branch from 678bb06 to 07e9891 on April 8, 2023 at 20:40
Commits:
- optimize with shared memory
- better number of threads
- update test
- temp disable test
- update
@suquark force-pushed the memcpy4flashattn branch from 07e9891 to e21845e on April 8, 2023 at 20:45
@WoosukKwon (Collaborator) commented:

Hey @suquark, thanks for the PR! A quick question: have you measured the performance difference between the two kernels, before and after the optimization?

@suquark closed this on Apr 11, 2023
@WoosukKwon reopened this on Apr 11, 2023
@suquark (Contributor, Author) commented Apr 11, 2023

See the PR description above for the optimized-kernel performance comparison.

@WoosukKwon (Collaborator) left a review comment:

LGTM.
