5x faster text generation on multi-GPU setups (+ lower VRAM consumption) #1394

@emvw7yf

Description

TL;DR: the patch below makes multi-GPU inference 5x faster.

I noticed that text generation is significantly slower on multi-GPU setups than on a single GPU. Some results, using llama models with the full 2048-token context window (I also tested GPT-J and the results are similar; a rough reproduction sketch follows the table):

| model | quantization | GPU configuration | seconds/token | total peak VRAM usage |
|---|---|---|---|---|
| 7B | 16-bit | 1x4090 | 0.025 | 14996 MiB |
| 7B | 16-bit | 2x4090 | 0.22 (8x slower) | 16484 MiB |
| 7B w/patch | 16-bit | 2x4090 | 0.026 | 15891 MiB |
| 13B | 8-bit w/bnb | 1x4090 | 0.067 | 16076 MiB |
| 13B | 8-bit w/bnb | 2x4090 | 0.34 (5x slower) | 18087 MiB |
| 13B w/patch | 8-bit w/bnb | 2x4090 | 0.070 | 16883 MiB |
| 13B | 4-bit w/gptq | 1x4090 | 0.036 | 10219 MiB |
| 13B | 4-bit w/gptq | 2x4090 | 0.228 (6x slower) | 11091 MiB |
| 13B w/patch | 4-bit w/gptq | 2x4090 | 0.038 | 10336 MiB |
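For reference, a minimal sketch of how seconds/token can be measured; the checkpoint name, prompt length, and token count below are placeholders, not the exact setup used for the table:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "huggyllama/llama-7b"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
).eval()

# Fill most of the 2048-token context window so the KV cache is near its maximum size.
prompt_ids = torch.randint(0, tokenizer.vocab_size, (1, 1900)).to(model.device)

new_tokens = 128
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    model.generate(prompt_ids, max_new_tokens=new_tokens, do_sample=False)
torch.cuda.synchronize()
print(f"{(time.time() - start) / new_tokens:.3f} seconds/token")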

This makes using large models prohibitively slow: running a 65B model on 4x3090 takes more than 4 seconds/token, which is unusable for interactive applications.

After some profiling and debugging, I narrowed this down to a simple problem:

  1. accelerate wraps the LlamaModel and moves the HUGE past_key_values tensor to the execution_device of LlamaModel (which is GPU 0) at the beginning of inference;
  2. it also wraps each LlamaDecoderLayer and moves the same past_key_values (which stays constant during the entire inference pass) from GPU 0 to that layer's execution device, repeatedly shuttling it between GPUs for every layer that is not on GPU 0.

This unnecessary repeated movement consumes up to 85% of the inference time. Furthermore, because it leaves a copy of past_key_values on each GPU, it significantly increases VRAM usage (although I didn't measure the exact overhead).
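For a sense of scale, a rough back-of-the-envelope estimate (not a measured figure): for a 7B llama model at the full 2048 context in fp16, the cache comes out to about 1 GiB.

# LLaMA-7B: 32 layers, 32 heads, head_dim 128, fp16 (2 bytes), batch 1
layers, heads, head_dim, seq_len, batch, bytes_per_el = 32, 32, 128, 2048, 1, 2
per_layer = 2 * batch * heads * seq_len * head_dim * bytes_per_el  # K and V
total = layers * per_layer
print(f"per layer: {per_layer / 2**20:.0f} MiB, full cache: {total / 2**30:.2f} GiB")
# per layer: 32 MiB, full cache: 1.00 GiB

Repeatedly shipping tensors of that size between GPUs on every decoder layer of every generated token, and keeping extra copies of them around, is consistent with both the slowdown and the higher peak VRAM in the table above.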

I'm not familiar enough with the accelerate code base to fix the root cause, but here's a simple patch that works around the problem. It keeps past_key_values sharded across GPUs, so it is never moved between GPUs (saving execution time) and its VRAM usage is split across GPUs (saving VRAM).

  1. Save this as llama_accelerate_path.py:
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
from accelerate.utils import find_device, send_to_device
from typing import Mapping


def send_to_device_except(data, device, non_blocking=False, skip_keys=()):
    # Same as accelerate's send_to_device, but leaves any entry listed in
    # skip_keys on whatever device it already occupies.
    if isinstance(data, Mapping):
        return type(data)({
            k: v if k in skip_keys else send_to_device(v, device, non_blocking)
            for k, v in data.items()
        })
    else:
        return send_to_device(data, device, non_blocking)


class AlignLogitsHook(AlignDevicesHook):
    # Like AlignDevicesHook, but never moves past_key_values between devices:
    # each layer's cache stays on the GPU where that layer runs.
    def pre_forward(self, module, *args, **kwargs):
        if self.io_same_device:
            self.input_device = find_device([args, kwargs])

        return (
            send_to_device(args, self.execution_device),
            send_to_device_except(kwargs, self.execution_device, skip_keys=("past_key_values",)),
        )

    def post_forward(self, module, output):
        if self.io_same_device and self.input_device is not None:
            output = send_to_device_except(output, self.input_device, skip_keys=("past_key_values",))
        return output


def apply_to_model(model):
    # Replace the top-level hook installed by accelerate with one that skips
    # past_key_values when aligning inputs and outputs to a single device.
    hook = AlignLogitsHook(execution_device=model._hf_hook.execution_device, io_same_device=True)
    add_hook_to_module(model, hook)
    return model
  2. Add this to the model loading code:
  # model loading with device_map="auto"
  # model = AutoModelForCausalLM.from_pretrained(...., device_map="auto").eval()

  # apply the patch (this works for llama models only):
  import llama_accelerate_path
  model = llama_accelerate_path.apply_to_model(model)
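To confirm that the cache really stays sharded after applying the patch, run a single forward pass with use_cache=True and inspect where each layer's cache lives (a quick sanity check, assuming model was loaded with device_map="auto" as above):

import torch

input_ids = torch.randint(0, 32000, (1, 16)).to(model.device)  # dummy prompt; 32000 = llama vocab size
with torch.no_grad():
    out = model(input_ids, use_cache=True)

# With the patch, layers placed on different GPUs report different devices here;
# without it, the top-level hook moves everything back to GPU 0.
for i, (k, v) in enumerate(out.past_key_values):
    print(f"layer {i}: keys on {k.device}, values on {v.device}")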

NOTE: there is another redundancy: attention_mask and some other tensors are also copied from GPU 0 to GPU n for every layer's execution. It can be solved in a similar way, but it's less of a problem because attention_mask is much smaller than past_key_values.
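For completeness, the same idea could be applied there too: keep one copy of the current attention_mask per GPU and reuse it across all decoder layers on that device, instead of re-sending it from GPU 0 each time. The sketch below is untested and uses made-up names; it would have to replace the AlignDevicesHook that accelerate attaches to each LlamaDecoderLayer:

import torch
from accelerate.hooks import AlignDevicesHook

# Shared across all layer hooks: at most one attention_mask copy per device,
# refreshed whenever a new mask tensor shows up (i.e. on every generation step).
_mask_cache = {"ptr": None, "copies": {}}


def _mask_on(mask, device):
    if _mask_cache["ptr"] != mask.data_ptr():
        _mask_cache["ptr"] = mask.data_ptr()
        _mask_cache["copies"] = {}
    copies = _mask_cache["copies"]
    if device not in copies:
        copies[device] = mask.to(device)
    return copies[device]


class CachedMaskAlignDevicesHook(AlignDevicesHook):
    def pre_forward(self, module, *args, **kwargs):
        mask = kwargs.get("attention_mask")
        if isinstance(mask, torch.Tensor):
            kwargs["attention_mask"] = _mask_on(mask, self.execution_device)
        return super().pre_forward(module, *args, **kwargs)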

It would be great if a more universal version of this fix could be merged into the accelerate mainline.
