Static Cache is broken with multi-gpu inference #32624

@mobicham

Description

System Info

  • transformers version: 4.44.0
  • Platform: Linux-6.5.0-15-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0.dev20240812+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Currently, setting the cache type to static via model.generation_config.cache_implementation = "static" or using StaticCache directly breaks with multi-GPU inference. It throws the error shown below, probably because the cache is not placed on the right device for some layers.
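
A minimal reproduction sketch, assuming the model is sharded across the two GPUs with device_map="auto"; the exact model ID, dtype, prompt, and generation length are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"  # illustrative: any model sharded over 2 GPUs
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  # splits the layers across cuda:0 and cuda:1
)

# Switching from the default dynamic cache to the static cache triggers the error
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))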

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:640, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    637 if past_key_value is not None:
    638     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    639     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 640     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    642 key_states = repeat_kv(key_states, self.num_key_value_groups)
    643 value_states = repeat_kv(value_states, self.num_key_value_groups)

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1083, in StaticCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1080 try:
   1081     # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
   1082     cache_position.to(device=k_out.device)
-> 1083     k_out.index_copy_(2, cache_position, key_states)
   1084     v_out.index_copy_(2, cache_position, value_states)
   1085 except NotImplementedError:
   1086     # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA_index_copy_)

I can reproduce this with Llama 3 70B on 2xA100. Dynamic cache works fine in the same multi-GPU setup, and the static cache works fine on a single GPU with a smaller (quantized) version of the same model.
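
For what it's worth, in the snippet quoted in the traceback the result of cache_position.to(device=k_out.device) is never assigned back, and Tensor.to() is not in-place, so cache_position stays on its original device when index_copy_ runs. If that is indeed the cause, a possible fix might look like the sketch below (based only on the lines quoted above, not on the full source tree):

# Sketch of a possible fix inside StaticCache.update (transformers/cache_utils.py),
# rewriting only the lines quoted in the traceback above.
# Tensor.to() returns a new tensor rather than moving the original, so the
# result has to be assigned back before it is used as the index.
cache_position = cache_position.to(device=k_out.device)
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)

This may not be the whole story if the per-layer cache tensors themselves end up allocated on the wrong device, but the unassigned .to() call is a no-op either way.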

Expected behavior

Static cache generation should work when the model is sharded across multiple GPUs.
