
llama : revisit using flash attention for prompt processing (a.k.a. prefill) + GPU implementation #3365

Closed
ggerganov opened this issue Sep 27, 2023 · 12 comments
Labels
performance Speed related topics research 🔬

Comments

@ggerganov
Owner

Sample Flash Attention usage on the CPU has been demonstrated here:

#778

Although it didn't provide an improvement for text-generation workloads, it might be beneficial during prompt processing.
Investigate whether that is the case and attempt Metal and CUDA implementations.
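
For reference, here is a minimal CPU-only sketch of the tiled/online-softmax idea behind FlashAttention (purely illustrative, not ggml code; names and shapes are made up). The point is that K and V are streamed in blocks while only a running max, normalizer and output accumulator are kept per query row, so the full score matrix is never materialized:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Attention for a single query row q[d] against keys K[n*d] and values V[n*d],
// streaming the keys/values in blocks (sketch of the online-softmax trick).
void attn_row_streamed(const float * q, const float * K, const float * V,
                       float * out, int n, int d, int block) {
    std::vector<float> acc(d, 0.0f);              // running sum of p * V rows
    float m = -INFINITY;                          // running max of scores
    float l = 0.0f;                               // running softmax denominator
    const float scale = 1.0f / std::sqrt((float) d);

    for (int j0 = 0; j0 < n; j0 += block) {
        const int j1 = std::min(j0 + block, n);   // this K/V block is what would sit in cache/SRAM
        for (int j = j0; j < j1; ++j) {
            float s = 0.0f;
            for (int k = 0; k < d; ++k) s += q[k] * K[j*d + k];
            s *= scale;

            const float m_new = std::max(m, s);
            const float corr  = std::exp(m - m_new);   // rescale everything accumulated so far
            const float p     = std::exp(s - m_new);

            l = l*corr + p;
            for (int k = 0; k < d; ++k) acc[k] = acc[k]*corr + p * V[j*d + k];
            m = m_new;
        }
    }
    for (int k = 0; k < d; ++k) out[k] = acc[k] / l;
}
```

Prompt processing (prefill) does this for many query rows at once, which is where the tiling could pay off compared to materializing the full [n_ctx, n_ctx] score matrix.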

@netrunnereve
Contributor

netrunnereve commented Sep 27, 2023

Does the patch in #778 still work with the latest master? That would be the easiest way for us to try this out.

@ggerganov
Owner Author

Haven't tested, but I think it should work. This implementation is just for the CPU.
Even if it does not show an advantage there, we should still try to implement a GPU version and see how it performs.

@fzzylogic

Just putting this here FWIW: the creators of FlashAttention released FlashDecoding, which can apparently improve inference by up to 8x.

@lapp0

lapp0 commented Nov 4, 2023

Published yesterday, FlashDecoding++

"Our extensive results show that FlashDecoding++ achieves an average of 1.37× speedup compared with FlashDecoding"

@BrainSlugs83

BrainSlugs83 commented Nov 27, 2023

FYI Flash Attention 2 also exists now: https://arxiv.org/abs/2307.08691 (published in July) whereas #778 is from April... so that PR might be a good starting place, but it's not testing the latest thing. -- Supposedly the new algorithm achieves about a 2x speedup vs the original.

Off topic, but for CPU inferencing speed-ups, this might be applicable?
https://huggingface.co/papers/2311.00502 -- they're claiming 20ms to 30ms per token with Llama 7B on a 4th gen Xeon CPU... which is kind of impressive IMO. -- Even on a 4bit model with my RTX 2070 I'm not getting those speeds out of llama.cpp! -- It looks like they're doing stuff to take advantage of the CPU cache more effectively?

(I mention it because it seems like kind of a sister to Flash Attention / Decoding which are optimized to take advantage of the GPU architecture, and this seems to take advantage of the CPU architecture -- I only did a quick skim though.)

@FSSRepo
Collaborator

FSSRepo commented Nov 29, 2023

FYI Flash Attention 2 also exists now: https://arxiv.org/abs/2307.08691

Flash Attention 2 is oriented toward GPUs and uses tensor cores.

@BrainSlugs83

BrainSlugs83 commented Nov 29, 2023

Flash Attention 2 is oriented toward GPUs and uses tensor cores.

Right, so is flash attention 1 though... And Llama.cpp has GPU support via CUDA, does it not?

Flash attention 1 paper:

We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes.

Anyway, for CPU-specific optimizations, that "Efficient LLM Inference on CPUs" paper I mentioned is getting much better speeds on much worse hardware than what I'm seeing out of the box with llama.cpp (even my Q2 is slower than their Q4); so it might be worth looking into, right? -- Any objection to me creating a feature request specifically for that? -- Or is it not applicable to llama.cpp's architecture?

@FSSRepo
Collaborator

FSSRepo commented Nov 29, 2023

I am trying to create the CUDA kernel, but I am having trouble finding a way to reduce memory usage as proposed in the paper (using shared memory, i.e. the on-chip cache or SRAM, which is very small: less than 10 MB). The implementation in the original repository requires NVIDIA CUTLASS (a very heavy library with many dependencies), which handles the operations at a high level, whereas here we are working with CUDA at a low level, using custom kernels.

@ggerganov @slaren
I will be working on a flash attention kernel in CUDA, and I will need some feedback; perhaps someone will notice something I am overlooking in the implementation. This CUDA kernel may not help much with speed for llama.cpp, but for stable-diffusion.cpp, whose self-attention requires very large multiplications ([4096, 4096, 8], ~512 MB peak memory, for a 512x512 image and [16384, 16384, 8], ~8 GB peak memory, for a 1024x1024 image), it would definitely help a lot with memory usage and performance.

The part that is not clear to me is how it manages to keep such a long vector in SRAM: for a 512x512 image, for example, a single score row could already take 16 KB.
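
If it helps: the trick is that the long score row never lives in SRAM in full. Each thread block keeps one Q tile, one K/V tile, and per-row running (max, sum) statistics; the rest of the sequence is streamed through tile by tile, so the shared-memory footprint is independent of the sequence length. A rough budget, assuming fp16 tiles with Br = Bc = 64 and head dim 64 (illustrative numbers, not what a tuned kernel would necessarily pick):

```cpp
// Back-of-the-envelope shared-memory budget for one thread block (illustrative only).
constexpr int Br = 64;   // query rows per tile
constexpr int Bc = 64;   // key/value columns per tile
constexpr int D  = 64;   // head dimension
constexpr int E  = 2;    // bytes per element (fp16)

constexpr int q_tile = Br * D  * E;   // 8 KiB
constexpr int k_tile = Bc * D  * E;   // 8 KiB
constexpr int v_tile = Bc * D  * E;   // 8 KiB
constexpr int s_tile = Br * Bc * E;   // 8 KiB of scores (often kept in registers instead)
constexpr int stats  = Br * 2  * 4;   // running max + sum per row, fp32

constexpr int total = q_tile + k_tile + v_tile + s_tile + stats;   // ~32.5 KiB
static_assert(total <= 48 * 1024, "fits a 48 KiB per-block shared-memory budget");
```

For the 512x512 case the 4096-long sequence would simply be consumed in 4096 / Bc = 64 such tiles per query block, carrying only the running statistics between tiles.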

@RonanKMcGovern

It would be great to add flash encoding and decoding type features. I see them as three things:

  1. Flash attention 2 (avoid re-reading the full key x query matrices by carrying the attention calculation all the way through in tiles).
  2. Paged attention (see vLLM). Avoids fragmentation of memory.
  3. Flash decoding - a bit like flash attention, but parallelizing the attention calculation along the sequence dimension for each decoding step (see the sketch below).

Right now I believe TGI has 1 & 3, while vLLM has 1 & 2.

1 & 2 will help a lot with using llama.cpp for training on longer contexts.
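
On point 3, a rough sketch of the split idea behind flash decoding (illustrative C++ only, not tied to any existing kernel): each chunk of the KV cache is reduced independently into an un-normalized partial output with its own (max, normalizer), and the partials are then merged with a log-sum-exp style rescaling:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct Partial {
    std::vector<float> o;   // un-normalized sum of exp(score - m) * V over this KV chunk
    float m;                // max score within the chunk
    float l;                // sum of exp(score - m) within the chunk
};

// Merge per-chunk partials (each produced independently, e.g. one per thread block)
// into the final attention output for one query (sketch).
void merge_partials(const std::vector<Partial> & parts, float * out, int d) {
    float m = -INFINITY;
    for (const auto & p : parts) m = std::max(m, p.m);

    float l = 0.0f;
    std::vector<float> acc(d, 0.0f);
    for (const auto & p : parts) {
        const float corr = std::exp(p.m - m);            // rescale to the global max
        l += p.l * corr;
        for (int k = 0; k < d; ++k) acc[k] += p.o[k] * corr;
    }
    for (int k = 0; k < d; ++k) out[k] = acc[k] / l;     // final softmax normalization
}
```

That split along the sequence is what keeps the GPU busy during single-token decoding with long contexts, where there is otherwise only one query row to parallelize over.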

@YehowshuaScaled

@FSSRepo, are you still working on the CUDA kernels? Do you have a branch? What kinds of memory issues have you been hitting?

Contributor

github-actions bot commented Apr 3, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 3, 2024
@ggerganov
Owner Author

#5021

@ggerganov ggerganov reopened this Apr 4, 2024
@ggerganov ggerganov removed the stale label Apr 4, 2024