
Request to support FlashAttention in cuda attention.cc #1300

Closed
nemoramo opened this issue Jun 16, 2023 · 23 comments
Labels
enhancement New feature or request

Comments

@nemoramo

FlashAttention can greatly reduce memory usage and speed up attention, even during inference. Any plan to support this implementation:
https://github.com/facebookresearch/xformers/tree/main/xformers/csrc/attention

@guillaumekln guillaumekln added the enhancement New feature or request label Jun 20, 2023
@jgcb00
Contributor

jgcb00 commented Jul 18, 2023

Hi,
A new implementation was released:
https://tridao.me/publications/flash2/flash2.pdf
It brings a ~50% improvement in TFLOPS on the forward pass compared to the old FlashAttention implementation, and a massive improvement compared to the vanilla attention mechanism:
https://github.com/Dao-AILab/flash-attention

@gordicaleksa

Hi @guillaumekln! How hard would it be to implement this?

Can you maybe give us some pointers? Most of the popular libraries support v1.0 already

@jgcb00
Contributor

jgcb00 commented Jul 21, 2023

Hi,
I think V2 will be much simpler to implement, as it comes with a higher-level library and is compatible with more GPUs.
It might still restrict the GPUs on which CTranslate2 can run, or we would need to add a special field to request flash attention.

@guillaumekln
Collaborator

At this time FlashAttention is mostly useful for training or when processing a long prompt. However, during inference most of the time is usually spent in the iterative decoding where the bottleneck is somewhere else.

It seems the author is still working to optimize this inference part where the input shape is different: Dao-AILab/flash-attention#346 (comment)
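As a rough illustration of that shape difference (plain PyTorch, not CTranslate2 code; the sizes are made up): during prefill the query spans the whole prompt, while each iterative decoding step attends with a single-token query against the cached keys/values, so the attention kernel itself is rarely the bottleneck there.

import torch
import torch.nn.functional as F

# Illustrative sizes only, not CTranslate2 internals.
batch, heads, head_dim, prompt_len = 1, 16, 64, 2048
device, dtype = "cuda", torch.float16

# Prefill: the query spans the whole prompt, so attention works on a large
# (prompt_len x prompt_len) problem -- the case FlashAttention targets.
q = torch.randn(batch, heads, prompt_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Iterative decoding: one new token per step attends to the cached K/V, so the
# attention itself is tiny and other costs (weight loading, etc.) dominate.
q_step = torch.randn(batch, heads, 1, head_dim, device=device, dtype=dtype)
decode_out = F.scaled_dot_product_attention(q_step, k, v)

print(prefill_out.shape, decode_out.shape)  # (1, 16, 2048, 64) / (1, 16, 1, 64)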

Also I can't find an end-to-end benchmark using FlashAttention for inference. Do you know where we can find one? I only see benchmarks for end-to-end training or for the attention module itself.

@MrigankRaman

But I believe it would be able to reduce VRAM usage further. Can we get some support for running memory-efficient attention?

@BBC-Esq

BBC-Esq commented Nov 30, 2023

Hello all. Just thought I'd post a question about Flash Attention 2 here:

https://github.com/Dao-AILab/flash-attention

Apparently it's making big waves and seems very powerful. Does anyone plan on seeing if it's something that could be included?

I reviewed the prior comments and suggest that we change the topic to Flash Attention 2. I know that guillaumekln is no longer with faster-whisper, but hopefully one of the admins can weigh in on this possibly powerful feature to include in ctranslate2!

@jgcb00
Contributor

jgcb00 commented Nov 30, 2023

Hi, my thoughts on this: there are some major pros and some cons:

Pro :

  • Reduce VRAM usage,
  • Flash-decoding improves speed on long-sequence generation (I don't know if something similar is already implemented)
  • Faster inference

Cons :

  • Introduce a dependency
  • It's not compatible with all GPUs, so it will be trickier to work with

Could we have your thoughts on this development?
What work would be required?

@vince62s
Member

vince62s commented Nov 30, 2023

Eventually this will be included, but it is not the same story to include a pip package (and we did include flash2 in OpenNMT-py) as it is to link a C++ package that is moving quite frequently. Of course we don't want to drop the current path for scaled dot-product attention.
It takes time, but new C++ developers are very welcome.

Bear in mind that native PyTorch is not dead:
https://pytorch.org/blog/accelerating-generative-ai-2/?hss_channel=lcp-78618366
https://forum.opennmt.net/t/opennmt-py-v3-4-3-released-blazing-fast-beam-search-inference/5546
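For reference on that native path: recent PyTorch releases can dispatch scaled dot-product attention to a fused FlashAttention kernel with no extra dependency. A minimal sketch, assuming PyTorch 2.x and an FP16-capable GPU (tensor sizes are made up; unrelated to CTranslate2's own kernels):

import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the fused FlashAttention backend (PyTorch 2.x). If the GPU,
# dtype, or shapes are unsupported, this raises instead of silently falling back.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 16, 1024, 64])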

@junchen6072

Another repo claims flash attention helps make transcription much faster: https://github.com/Vaibhavs10/insanely-fast-whisper
From my read, about 2x?

@minhthuc2502
Collaborator

minhthuc2502 commented Apr 9, 2024

CTranslate2 will soon support FlashAttention 2 following PR #1651. I will do the release ASAP. I ran some tests and saw a performance improvement with long prompts. It only runs on GPU architectures >= sm80, as mentioned in the original repo. It would be great if you guys could test it.
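As a rough usage sketch with the CTranslate2 Python API, assuming the flash_attention constructor flag added by that PR (the model directory and tokens are placeholders; FlashAttention 2 needs an sm80+ GPU and half precision):

import ctranslate2

generator = ctranslate2.Generator(
    "llama-2-7b-ct2",           # placeholder path to a converted model
    device="cuda",
    compute_type="float16",     # FA2 kernels require FP16/BF16
    flash_attention=True,       # set to False to keep the standard MHA path
)

# Generation itself is unchanged; only the attention kernel differs.
# The tokens below are purely illustrative; use the model's real tokenizer.
results = generator.generate_batch([["<s>", "Hello"]], max_length=32)
print(results[0].sequences[0])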

@junchen6072

Thanks, looking forward to testing it with faster-whisper!

@cqchangm149016

CTranslate2 will soon support FlashAttention 2 following PR #1651. I will do the release ASAP. I ran some tests and saw a performance improvement with long prompts. It only runs on GPU architectures >= sm80, as mentioned in the original repo. It would be great if you guys could test it.

Thanks, looking forward to testing it with faster-whisper!

This is great! Any chance you could provide some tips as to how to test this on faster-whisper?

@minhthuc2502
Collaborator

This is great! Any chance you could provide some tips as to how to test this on faster-whisper?

Make sure you have an Ampere GPU or newer. You can just set flash_attention=True when loading the model to use FlashAttention instead of the standard MHA.
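For example, a minimal sketch with faster-whisper (the model size and audio path are placeholders):

from faster_whisper import WhisperModel

# flash_attention=True requires an Ampere (sm80) or newer GPU and half precision.
model = WhisperModel("large-v2", device="cuda", compute_type="float16",
                     flash_attention=True)

segments, info = model.transcribe("audio.wav")   # placeholder audio file
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))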

@AvivSham

AvivSham commented Jun 3, 2024

Hi @minhthuc2502,
Do you have a benchmark comparing Faster Whisper with and without Flash Attention?

@minhthuc2502
Collaborator

minhthuc2502 commented Jun 4, 2024

Hello, I did not make a benchmark with Faster Whisper, but there are some benchmarks for FlashAttention with some LLM models here.

@BBC-Esq

BBC-Esq commented Jun 4, 2024

Hi @minhthuc2502, Do you have a benchmark comparing Faster Whisper with and without Flash Attention?

I haven't benched Whisper in relation to flash attention, but my hypothesis is that it will not make much of a difference for a beam size of 1, but that it "might" if the beam size is increased. However, the benefit will likely not be nearly as great as with stereotypical chat models. I base this conclusion on the following:

  • My testing of flash attention indicates a noticeable VRAM savings and speedup for chat models run with ctranslate2, except for Llama2 models (likely due to architectural differences), and this is most noticeable when you increase the beam size. Thus, FA2 seems to provide improvements "across the board" when you increase the beam size, and there's no indication that this wouldn't also be the case for Whisper (as opposed to chat) models.

  • However, my testing was geared towards a "RAG" use case. This scenario involves sending a single question to an LLM for a response, accompanying the question with "contexts" from a corpus, the goal being for the LLM to respond solely based on the provided contexts. The question and the provided contexts, together, constitute the "prompt" for the LLM to process. The "prompt" in my testing was approximately 1000 tokens.

  • In the linked conversation that @minhthuc2502 provided he states that the benefits of FA2 should be greater the longer the "prompt." Since my "prompt" was only ~1000 tokens, and if what @minhthuc2502 says is true, it means that I didn't fully test the benefits of flash attention in ctranslate2...again, my testing was geared towards RAG.

  • In a non-RAG scenario, such as when you converse with a chat LLM in a multi-turn conversation, the entire conversation is sent to the LLM each time...and each time the user's new message or the LLM's response is appended to the chat history and resent to the LLM. This is commonly referred to as "memory" and is different than a single question like my RAG scenario. In a conversation with memory, the chat history can easily increase above 1000 tokens and will often exceed the LLM's context window. Again, I didn't test for a "prompt" above 1000 tokens.

  • With this background, the Whisper models themselves can only process up to 30 seconds of audio in a given chunk. This is an inherent limitation based on how the Whisper models were trained by OpenAI. The VAD (e.g. see the faster-whisper repo) removes silent portions of the 30 second window so as to pack only speech into it, but the 30 second window remains...

  • As such, you likely won't see the long-prompt benefit that @minhthuc2502 mentions with Whisper, unless you can cram way more than 1000 tokens into that 30-second window. However, as I mentioned, this doesn't disturb my findings regarding the benefits of flash attention when increasing the beam size, even with smaller 1000-token chunks.

Keep in mind that I just haven't had time to test this. In my testing I try to honestly represent people's hard work, but I'm not a programmer by trade and this is a hobby of mine, so...take it with a grain of salt. Hope this helps!

@AvivSham

AvivSham commented Jun 4, 2024

Hi @BBC-Esq,
thank you for your insights!
Regarding whisper and the number of tokens - every 30-second window is converted to Mel-spectrogram features, roughly 3,000 frames each with 80 features. Therefore I expected to see some boost when using FA. Additionally, the default beam size for faster-whisper is 5.

@minhthuc2502 The reason I was asking about faster whisper FA benchmark is that I do not see any improvement in speed when loading the whisper model with FA.

Here is the code I used to benchmark:

import time
import torch

from faster_whisper import WhisperModel


def benchmark(model):
    times = []

    # Warmup
    for i in range(10):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr",
        )
        segments = list(segments)

    # Benchmark: transcribe() is lazy, so the timing below measures the actual
    # decoding triggered by consuming the segments generator.
    for i in range(100):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr"
        )
        past = time.time()
        segments = list(segments)
        torch.cuda.synchronize()
        times.append(time.time() - past)

    # Drop the first measurement, which can still include warmup effects.
    times = times[1:]
    print(f"Mean inference time: {sum(times) / len(times)}")
    print(f"\nTIMES: {times}")


if __name__ == '__main__':

    # model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=True)
    model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=False)
    benchmark(model)

The results for the above code snippet are (after running it twice independently):
With FA:
Mean inference time: 0.8763072201699922
W/O FA:
Mean inference time: 0.8619994466955011


About the setup:

  • the audio file is ~50 sec long.
  • GPU: a10G
  • ctranslate2 version: 4.2.1

Is this result expected? If not, what can be done to make it faster?

@BBC-Esq

BBC-Esq commented Jun 4, 2024

@AvivSham If you're asking for my opinion on how to speed things up generally, faster-whisper has a pull request for batch processing that's not yet approved. If you don't want to wait for it you can use the WhisperS2T library, but once the faster-whisper pull request is approved the speed will be comparable to that of WhisperS2T.

But if you're asking how to make it faster with flash attention, based on the assumption that you might not be using flash attention correctly with faster-whisper...I'm afraid I can't really help. @minhthuc2502 might be able to help, but what I've learned is that those kinds of questions are better posted on the faster-whisper repo. Those peeps are more responsible for actually implementing new features provided by ctranslate2. I.e., unless there's a problem with ctranslate2's implementation of flash attention for whisper models IN GENERAL, the issue would be better addressed at faster-whisper.

With that being said, I can confirm that flash attention works for "chat" models so I'd be surprised if there's some kind of core issue with the ctranslate2 library that prevents it from working just with Whisper models...

@BBC-Esq

BBC-Esq commented Jun 4, 2024

BTW, when I said "I can't really help" it's not that I don't want to...it's just that I'm tapped out as far as my personal knowledge goes...Programming is only a hobby for me, after all. ;-)

@BBC-Esq

BBC-Esq commented Jun 4, 2024

@AvivSham You might also test your script using beam sizes 1-5 and see if there's a difference (see the sketch below)? If there's a noticeable difference between using flash attention and not, you could perhaps rule out the possibility that the flash attention parameter isn't being used at all. At the end of this discussion they do confirm that flash attention can be used...

SYSTRAN/faster-whisper#598
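Something along these lines, reusing the paths from the benchmark above (a rough, untested sketch; no warmup pass included):

import time
from faster_whisper import WhisperModel

def mean_time(model, beam_size, runs=10):
    # transcribe() is lazy; consuming the generator triggers the actual decoding.
    times = []
    for _ in range(runs):
        start = time.time()
        segments, _ = model.transcribe("sample_1.wav", language="fr",
                                       beam_size=beam_size)
        list(segments)
        times.append(time.time() - start)
    return sum(times) / len(times)

for flash in (False, True):
    model = WhisperModel("/home/user/whisper-large-v2-ct2", device="cuda",
                         compute_type="float16", flash_attention=flash)
    for beam_size in range(1, 6):
        print(f"flash_attention={flash} beam_size={beam_size}: "
              f"{mean_time(model, beam_size):.3f}s")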

@AvivSham

AvivSham commented Jun 6, 2024

Thank you for your attempt to help! 😄 I will post this question directly in the faster-whisper repo while waiting for @minhthuc2502 's response.

@trungkienbkhn

For more information, I executed some benchmarks for faster-whisper with FlashAttention here.

@minhthuc2502
Collaborator

minhthuc2502 commented Jun 26, 2024

Thank you for your attempt to help! 😄 I will post this question directly in the faster-whisper repo while waiting for @minhthuc2502 's response.

In recent tests (I posted a benchmark with FA2), I noticed that with longer sequence lengths the difference between FA2 and standard MHA becomes more obvious. In the case of faster-whisper, however, the 30-second audio chunk is converted to an encoder input with shape (1, 80, 3000), see here. That sequence length is quite small to get much benefit from FA2.
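For reference, the 3000 comes directly from Whisper's fixed feature-extraction parameters (16 kHz audio, a 160-sample hop, and 80 mel bins over a 30-second chunk):

# Whisper's fixed feature-extraction parameters.
SAMPLE_RATE = 16_000   # Hz
HOP_LENGTH = 160       # samples between successive mel frames
N_MELS = 80
CHUNK_SECONDS = 30

frames = CHUNK_SECONDS * SAMPLE_RATE // HOP_LENGTH
print((1, N_MELS, frames))  # (1, 80, 3000) -> the encoder input shape above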
