TL;DR: the patch below makes multi-GPU inference 5x faster.
I noticed that text generation is significantly slower on multi-GPU than on a single GPU. Some results (using LLaMA models with the full 2048-token context window; I also tested GPT-J and the results are similar):
| model size | quantization | GPU configuration | seconds/token | total peak VRAM usage |
|---|---|---|---|---|
| 7B | 16-bit | 1x4090 | 0.025 | 14996 MiB |
| 7B | 16-bit | 2x4090 | 0.22 (8x slower) | 16484 MiB |
| 7B w/patch | 16-bit | 2x4090 | 0.026 | 15891 MiB |
| 13B | 8-bit (bnb) | 1x4090 | 0.067 | 16076 MiB |
| 13B | 8-bit (bnb) | 2x4090 | 0.34 (5x slower) | 18087 MiB |
| 13B w/patch | 8-bit (bnb) | 2x4090 | 0.070 | 16883 MiB |
| 13B | 4-bit (GPTQ) | 1x4090 | 0.036 | 10219 MiB |
| 13B | 4-bit (GPTQ) | 2x4090 | 0.228 (6x slower) | 11091 MiB |
| 13B w/patch | 4-bit (GPTQ) | 2x4090 | 0.038 | 10336 MiB |
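For reference, here is a rough sketch of how seconds/token figures of this kind can be measured; the checkpoint path, prompt length, and step count below are placeholders, not the exact setup behind the table:

```python
import time

import torch
from transformers import AutoModelForCausalLM

model_path = "path/to/llama-7b"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
).eval()

# Nearly fill the 2048-token context, then time a fixed number of decode steps.
input_ids = torch.randint(0, model.config.vocab_size, (1, 1948), device="cuda:0")
new_tokens = 100

torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    model.generate(
        input_ids,
        max_new_tokens=new_tokens,
        min_new_tokens=new_tokens,  # keep the step count fixed (no early EOS)
        do_sample=False,
    )
torch.cuda.synchronize()
print(f"{(time.perf_counter() - start) / new_tokens:.3f} seconds/token")
```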
This makes large models prohibitively slow: running a 65B model on 4x3090 takes more than 4 seconds/token, which is unusable for interactive applications.
After some profiling and debugging, I narrowed this down to a simple problem:
- accelerate wraps the `LlamaModel` and moves the HUGE `past_key_values` tensor to the `execution_device` of `LlamaModel` (which is GPU 0) at the beginning of the inference pass;
- it also wraps each `LlamaDecoderLayer` and moves the same `past_key_values` (which stays constant during the entire inference pass) from GPU 0 to that layer's execution device, repeatedly moving it between GPUs for every layer that is not on GPU 0.
This unnecessary repeated moving consumes up to 85% of the inference time. Furthermore, because it leaves a copy of `past_key_values` on each GPU, it significantly increases VRAM usage (although I didn't measure the exact number).
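To get a feel for the magnitude, here is a minimal timing sketch. The shapes are hypothetical (a LLaMA-7B-style cache at the full 2048-token context, fp16) and two GPUs are assumed; it measures only the raw cost of shuttling a cache of this size between GPUs, independent of accelerate:

```python
import time

import torch

# Hypothetical LLaMA-7B-like KV cache: 32 layers x (key, value),
# each tensor (1, 32, 2048, 128) in fp16 -- roughly 1 GiB in total, allocated on GPU 0.
past_key_values = tuple(
    (
        torch.randn(1, 32, 2048, 128, dtype=torch.float16, device="cuda:0"),
        torch.randn(1, 32, 2048, 128, dtype=torch.float16, device="cuda:0"),
    )
    for _ in range(32)
)

torch.cuda.synchronize()
start = time.perf_counter()
# One bulk copy of the whole cache to the second GPU.
moved = [(k.to("cuda:1"), v.to("cuda:1")) for k, v in past_key_values]
torch.cuda.synchronize()
print(f"copying ~1 GiB of cache across GPUs took {time.perf_counter() - start:.3f} s")
```

Under the behaviour described above, cross-GPU traffic of roughly this volume happens on every generated token.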
I'm not familiar enough with the accelerate code base to fix the root cause, but here is a simple patch that solves the problem. It keeps `past_key_values` sharded across the GPUs, so it is never moved between them (saving execution time) and its VRAM usage is split across the GPUs (saving VRAM).
- Save this as `llama_accelerate_path.py`:
```python
from typing import Mapping

from accelerate.hooks import AlignDevicesHook, add_hook_to_module
from accelerate.utils import find_device, send_to_device


def send_to_device_except(data, device, non_blocking=False, skip_keys=()):
    """Like send_to_device, but leaves the values listed in skip_keys where they are."""
    if isinstance(data, Mapping):
        return type(data)(
            {
                k: v if k in skip_keys else send_to_device(v, device, non_blocking)
                for k, v in data.items()
            }
        )
    else:
        return send_to_device(data, device, non_blocking)


class AlignLogitsHook(AlignDevicesHook):
    """Top-level hook that aligns everything except past_key_values, which stays sharded."""

    def pre_forward(self, module, *args, **kwargs):
        if self.io_same_device:
            self.input_device = find_device([args, kwargs])
        return (
            send_to_device(args, self.execution_device),
            send_to_device_except(kwargs, self.execution_device, skip_keys=("past_key_values",)),
        )

    def post_forward(self, module, output):
        if self.io_same_device and self.input_device is not None:
            output = send_to_device_except(output, self.input_device, skip_keys=("past_key_values",))
        return output


def apply_to_model(model):
    # Replace the hook on the top-level model (added by device_map="auto") with the variant above.
    hook = AlignLogitsHook(execution_device=model._hf_hook.execution_device, io_same_device=True)
    add_hook_to_module(model, hook)
    return model
```
- Add this to the model loading code:
```python
# Load the model with device_map="auto" as usual:
# model = AutoModelForCausalLM.from_pretrained(..., device_map="auto").eval()

# Apply the patch (this works for llama models only):
import llama_accelerate_path

model = llama_accelerate_path.apply_to_model(model)
```
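Presumably this works because `send_to_device` leaves a tensor alone when it is already on the target device, so the untouched per-layer hooks no longer trigger cross-GPU copies. A quick, hypothetical sanity check that the hook is active: after one forward pass, the returned cache should be spread over both GPUs instead of gathered on GPU 0.

```python
import torch

# Push a short dummy prompt through the patched model and inspect where the cache lives.
dummy_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda:0")
with torch.no_grad():
    out = model(dummy_ids, use_cache=True)

# With the patch this should print two devices (e.g. cuda:0 and cuda:1);
# without it, everything ends up on cuda:0.
print({key.device for key, value in out.past_key_values})
```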
NOTE: there is another redundancy: `attention_mask` and some other tensors are also copied from GPU 0 to GPU n for every layer's execution. It can be solved in a similar way, but it is less of a problem because `attention_mask` is much smaller than `past_key_values`.
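For a rough sense of that size difference, a back-of-the-envelope calculation with hypothetical LLaMA-7B-like shapes at the full 2048-token context:

```python
# Hypothetical LLaMA-7B-like shapes, fp16, 2048-token context.
layers, heads, head_dim, seq, bytes_fp16 = 32, 32, 128, 2048, 2

# Full KV cache: key + value per layer, each of shape (1, heads, seq, head_dim).
past_kv_bytes = layers * 2 * heads * seq * head_dim * bytes_fp16
# Attention mask handed to a single decoding step once the cache is full: about seq entries.
mask_bytes = seq * bytes_fp16

print(f"past_key_values ~ {past_kv_bytes / 2**20:.0f} MiB")  # ~ 1024 MiB
print(f"attention_mask  ~ {mask_bytes / 2**10:.0f} KiB")     # ~ 4 KiB
```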
It would be great if a more universal version of this fix could be merged into the accelerate mainline.