
update Bark FA2 docs #27400

Merged · 5 commits · Nov 10, 2023
Changes from 3 commits
54 changes: 43 additions & 11 deletions docs/source/en/model_doc/bark.md
@@ -44,6 +44,18 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16).to(device)
```

#### Using CPU offload

As mentioned above, Bark is made up of 4 sub-models, which are called sequentially during audio generation. In other words, while one sub-model is in use, the other sub-models sit idle.

If you're using a CUDA device, a simple way to benefit from an 80% reduction in memory footprint is to offload idle sub-models from the GPU to the CPU. This operation is called CPU offloading, and it takes a single line of code:

```python
model.enable_cpu_offload()
```

Note that 🤗 Accelerate must be installed before using this feature. [Here's how to install it.](https://huggingface.co/docs/accelerate/basic_tutorials/install)
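
Putting these pieces together, a minimal end-to-end sketch of CPU offloading during generation could look like the following; the prompt, the output handling, and the `scipy` export are illustrative assumptions that follow the standard Bark usage pattern:

```python
import torch
from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, BarkModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16 and move the model to the GPU
processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16).to(device)

# offload idle sub-models to the CPU (requires 🤗 Accelerate);
# they are moved back to the GPU on demand during generation
model.enable_cpu_offload()

inputs = processor("Hello, my name is Suno.")
inputs = {k: v.to(device) for k, v in inputs.items()}

audio_array = model.generate(**inputs)
audio_array = audio_array.cpu().numpy().squeeze().astype("float32")

# Bark stores its output sampling rate on the generation config
write_wav("bark_out.wav", rate=model.generation_config.sample_rate, data=audio_array)
```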

#### Using 🤗 Better Transformer

Better Transformer is an 🤗 Optimum feature that performs kernel fusion under the hood. You can gain 20% to 30% in speed with zero performance degradation. It only requires one line of code to export the model to 🤗 Better Transformer:
@@ -54,33 +66,53 @@ model = model.to_bettertransformer()

Note that 🤗 Optimum must be installed before using this feature. [Here's how to install it.](https://huggingface.co/docs/optimum/installation)
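
For reference, the export is a single line; a minimal sketch assuming `model` is the `BarkModel` loaded earlier:

```python
# export the model to Better Transformer (kernel fusion under the hood, via 🤗 Optimum)
model = model.to_bettertransformer()
```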

#### Using Flash Attention 2

Flash Attention 2 is an even faster, more optimized attention implementation than Better Transformer.

##### Installation

First, make sure to [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the [official documentation](https://github.com/Dao-AILab/flash-attention) of the flash-attn repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
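
If you are unsure whether your GPU qualifies, a quick check of its CUDA compute capability can help; treat this as a rough sanity check (Flash Attention 2 targets Ampere-or-newer GPUs), not an exhaustive compatibility test:

```python
import torch

# rough check only: Flash Attention 2 kernels target Ampere, Ada and Hopper GPUs,
# i.e. compute capability 8.0 or higher -- see the flash-attn README for the full list
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability {major}.{minor} -> FA2 likely supported: {major >= 8}")
else:
    print("no CUDA device available; Flash Attention 2 requires a CUDA GPU")
```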

##### Usage

To load and run a model using FA2, refer to the snippet below:

```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
```
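
To sanity-check the speedup on your own machine, a rough timing sketch is shown below, assuming the FA2 model loaded above; the prompt, processor call and timing logic are illustrative rather than the benchmark used for the numbers that follow:

```python
import time

import torch
from transformers import AutoProcessor

# time a single generation with the FA2-enabled model loaded above
# (torch.cuda.synchronize requires a CUDA device, which Flash Attention 2 needs anyway)
processor = AutoProcessor.from_pretrained("suno/bark-small")
inputs = processor("Hello, my name is Suno.")
inputs = {k: v.to(device) for k, v in inputs.items()}

torch.cuda.synchronize()
start = time.perf_counter()
audio_array = model.generate(**inputs)
torch.cuda.synchronize()
print(f"generation took {time.perf_counter() - start:.2f}s")
```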

##### Performance comparison


Flash Attention 2 is also consistently faster than Better Transformer, and its performance improves even more as batch sizes increase, as the following diagram shows.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ylacombe/benchmark-comparison/resolve/main/Bark%20Optimization%20Benchmark.png">
</div>

To put this into perspective, on an NVIDIA A100, batched generation with Flash Attention 2 gives you 17 times the [throughput](https://huggingface.co/blog/optimizing-bark#throughput) of the unoptimized, non-batched version, and is still 2 seconds faster end-to-end.
Review comment (Contributor): I don't really follow where this 17x number comes from? Are we comparing FA2 batched vs un-optimised non-batched?

Reply (Contributor Author): Exactly, I'll make it clearer

At batch size 8 on an NVIDIA A100, Flash Attention 2 is also 10% faster than Better Transformer, and 25% faster at batch size 16.
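
Those throughput gains come from batching prompts so that each sub-model call is amortized over several samples. A rough sketch of batched generation, assuming the processor pads a list of prompts to a common length (see the [Bark optimization blog post](https://huggingface.co/blog/optimizing-bark) for the full benchmark setup):

```python
from transformers import AutoProcessor

# assumes `model` (the FA2-enabled BarkModel) and `device` from the snippets above
processor = AutoProcessor.from_pretrained("suno/bark-small")

prompts = [
    "Hello, my name is Suno.",
    "Batching prompts keeps every Bark sub-model busy and raises throughput.",
]
inputs = processor(prompts)
inputs = {k: v.to(device) for k, v in inputs.items()}

audio_batch = model.generate(**inputs)   # one (padded) waveform per prompt
print(audio_batch.shape)                 # roughly (batch_size, num_samples)
```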


#### Combining optimization techniques

You can combine optimization techniques, and use CPU offload, half-precision and Flash Attention 2 (or 🤗 Better Transformer) all at once.

```python
from transformers import BarkModel
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)

# enable CPU offload
model.enable_cpu_offload()
```
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_one.md
@@ -36,7 +36,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 supports inference with Llama, Mistral, Falcon, and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites):
