llama : revisit using flash attention for prompt processing (a.k.a. prefill) + GPU implementation #3365
Comments
Does the patch in #778 still work with the latest master? That would be the easiest way for us to try this out.
Haven't tested, but I think it should work. This implementation is just for the CPU.
Just putting this here FWIW: the creators of FlashAttention have released Flash-Decoding, which can apparently improve inference by up to 8x.
Published yesterday, FlashDecoding++: "Our extensive results show that FlashDecoding++ achieves an average of 1.37× speedup compared with FlashDecoding."
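For reference, the core idea behind Flash-Decoding is to split the K/V cache into chunks, compute partial attention for each chunk in parallel, and then merge the per-chunk results with a log-sum-exp correction. Below is a minimal sketch of that merge step; `chunk_partial`, `merge_chunks` and the layout are illustrative choices for this example, not code from the paper or from llama.cpp.

```cpp
// Merge step for a Flash-Decoding style split-KV attention (sketch only).
#include <algorithm>
#include <cmath>
#include <vector>

// Partial result from attending one query head against one K/V chunk:
//   m   - max attention score seen in the chunk
//   l   - sum of exp(score - m) over the chunk
//   out - chunk-local attention output (already divided by l)
struct chunk_partial {
    float m;
    float l;
    std::vector<float> out; // length = head dimension
};

// Combine per-chunk partials into the exact softmax(QK^T)V result.
// Rescaling each chunk by exp(m_i - m_max) makes the merge numerically
// identical to a single softmax pass over the full sequence.
std::vector<float> merge_chunks(const std::vector<chunk_partial> & parts, int head_dim) {
    float m_max = -INFINITY;
    for (const auto & p : parts) {
        m_max = std::max(m_max, p.m);
    }

    std::vector<float> out(head_dim, 0.0f);
    float l_total = 0.0f;

    for (const auto & p : parts) {
        const float scale = p.l * std::exp(p.m - m_max);
        l_total += scale;
        for (int d = 0; d < head_dim; ++d) {
            out[d] += scale * p.out[d];
        }
    }

    for (int d = 0; d < head_dim; ++d) {
        out[d] /= l_total;
    }
    return out;
}
```

The claimed speedup comes from being able to spread the per-chunk work across the GPU even at batch size 1, where FlashAttention 1/2 would otherwise leave most of the device idle during decoding.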
FYI, Flash Attention 2 also exists now: https://arxiv.org/abs/2307.08691 (published in July), whereas #778 is from April... so that PR might be a good starting place, but it's not testing the latest thing. Supposedly the new algorithm achieves about a 2x speedup vs. the original. Off topic, but for CPU inference speed-ups, the "Efficient LLM Inference on CPUs" paper might be applicable? (I mention it because it seems like kind of a sister to Flash Attention / Flash-Decoding, which are optimized to take advantage of the GPU architecture, while this one takes advantage of the CPU architecture; I only did a quick skim, though.)
Flash Attention 2 is oriented to GPUs and uses tensor cores.
Right, but so is Flash Attention 1... And llama.cpp has GPU support via CUDA, does it not? Flash Attention 1 paper: https://arxiv.org/abs/2205.14135
Anyway, for CPU-specific optimizations, that "Efficient LLM Inference on CPUs" paper I mentioned is getting much better speeds on much worse hardware than what I'm seeing out of the box with llama.cpp (even my Q2 is slower than their Q4), so it might be worth looking into, right? Any objection to me creating a feature request specifically for that? Or is it not applicable to llama.cpp's architecture?
I am trying to create the CUDA kernel, but I am having trouble finding a way to reduce memory usage the way the paper proposes (using shared memory, i.e. the on-chip SRAM, which is tiny: well under 10 MB). The implementation in the original repository requires NVIDIA CUTLASS (a heavy library with many dependencies) to handle the operations at a high level, whereas here we work with CUDA at a low level using custom kernels. @ggerganov @slaren The part that is not clear to me is how it manages to fit such a long vector in SRAM; for a 512x512 image, for example, that could take 16 KB.
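On the SRAM question: the trick in the paper is that the full K/V is never resident in shared memory. Fixed-size K/V tiles are streamed through SRAM, and each query keeps only a running max, a running sum, and a head-dim-wide accumulator (the "online softmax"), so the SRAM budget is independent of sequence length. Below is a heavily simplified sketch under assumed choices: fp32, no causal mask, one query per thread, and the names `attn_tiled`, `D`, `KV_TILE` are made up for the example; it is not meant to match the CUTLASS-based reference kernels or llama.cpp code.

```cuda
// Flash-attention style tiling sketch: K/V streamed through shared memory
// in fixed-size tiles, online softmax kept per query.
#include <cuda_runtime.h>
#include <math.h>

template <int D, int KV_TILE>
__global__ void attn_tiled(const float * Q,  // [n_q,  D]
                           const float * K,  // [n_kv, D]
                           const float * V,  // [n_kv, D]
                           float       * O,  // [n_q,  D]
                           int n_q, int n_kv, float scale) {
    // only 2*KV_TILE*D floats of K/V are ever resident in SRAM
    __shared__ float sK[KV_TILE][D];
    __shared__ float sV[KV_TILE][D];

    const int iq = blockIdx.x*blockDim.x + threadIdx.x;

    float q[D];
    float acc[D];        // unnormalized output accumulator
    float m = -INFINITY; // running max of scores
    float l = 0.0f;      // running sum of exp(score - m)
    for (int d = 0; d < D; ++d) acc[d] = 0.0f;
    if (iq < n_q) {
        for (int d = 0; d < D; ++d) q[d] = Q[iq*D + d];
    }

    for (int kv0 = 0; kv0 < n_kv; kv0 += KV_TILE) {
        // cooperatively stage the current K/V tile in shared memory
        for (int i = threadIdx.x; i < KV_TILE*D; i += blockDim.x) {
            const int row = kv0 + i/D;
            sK[i/D][i % D] = row < n_kv ? K[row*D + i % D] : 0.0f;
            sV[i/D][i % D] = row < n_kv ? V[row*D + i % D] : 0.0f;
        }
        __syncthreads();

        if (iq < n_q) {
            for (int j = 0; j < KV_TILE && kv0 + j < n_kv; ++j) {
                float s = 0.0f;
                for (int d = 0; d < D; ++d) s += q[d]*sK[j][d];
                s *= scale;

                // online softmax: rescale previous state if a new max appears
                const float m_new = fmaxf(m, s);
                const float corr  = expf(m - m_new);
                const float p     = expf(s - m_new);
                l = l*corr + p;
                for (int d = 0; d < D; ++d) acc[d] = acc[d]*corr + p*sV[j][d];
                m = m_new;
            }
        }
        __syncthreads();
    }

    if (iq < n_q) {
        for (int d = 0; d < D; ++d) O[iq*D + d] = acc[d]/l;
    }
}
```

With D = 64 and KV_TILE = 32 the two tiles occupy 16 KB of shared memory regardless of sequence length, e.g. launched as `attn_tiled<64, 32><<<(n_q + 63)/64, 64>>>(Q, K, V, O, n_q, n_kv, 1.0f/sqrtf(64.0f))`. A production kernel would additionally tile over queries per block, use half precision, vectorized loads and tensor cores, and apply the causal mask; that is the part the reference implementation leans on CUTLASS for.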
It would be great to add flash encoding and decoding type features. I see them as three things:
Right now I believe TGI has 1 & 3, while vLLM has 1 & 2. 1 & 2 will help a lot with using llama.cpp for training on longer context.
@FSSRepo, are you still working on the CUDA kernels? Do you have a branch? What kinds of memory issues have you been hitting?
This issue was closed because it has been inactive for 14 days since being marked as stale.
Sample Flash Attention usage on the CPU has been demonstrated here: #778
Although it didn't provide improvement during text generation workloads, it might be beneficial during prompt processing.
Investigate if that is the case and attempt Metal and CUDA implementations.
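Not from the issue text, but when attempting the Metal/CUDA implementations it helps to have a scalar baseline to diff against. A minimal reference is sketched below; `attn_reference`, the row-major layout, and the absence of a causal mask are simplifications chosen for this example rather than llama.cpp's actual internals.

```cpp
// Naive single-head attention, useful only as a correctness baseline when
// checking a flash-attention path (CPU, Metal or CUDA) against it.
#include <algorithm>
#include <cmath>
#include <vector>

//   Q: [n_q, d]   K, V: [n_kv, d]   O: [n_q, d]
void attn_reference(const float * Q, const float * K, const float * V,
                    float * O, int n_q, int n_kv, int d) {
    const float scale = 1.0f/std::sqrt((float) d);
    std::vector<float> s(n_kv); // one full row of scores per query

    for (int i = 0; i < n_q; ++i) {
        // scores for query row i, tracking the max for numerical stability
        float m = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) dot += Q[i*d + c]*K[j*d + c];
            s[j] = dot*scale;
            m = std::max(m, s[j]);
        }

        // softmax over the row, then weighted sum of V
        float l = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            s[j] = std::exp(s[j] - m);
            l += s[j];
        }
        for (int c = 0; c < d; ++c) O[i*d + c] = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            const float w = s[j]/l;
            for (int c = 0; c < d; ++c) O[i*d + c] += w*V[j*d + c];
        }
    }
}
```

The per-row score buffer `s` is exactly what a flash-attention path avoids materializing. During prompt processing there are n_q such rows in flight (versus a single row per generated token), so that is where skipping the intermediate matrix and the extra passes over it can plausibly pay off.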