System Info
- transformers version: 4.44.0
- Platform: Linux-6.5.0-15-generic-x86_64-with-glibc2.35
- Python version: 3.10.14
- Huggingface_hub version: 0.24.5
- Safetensors version: 0.4.4
- Accelerate version: 0.33.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.0.dev20240812+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A100 80GB PCIe
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Currently, setting the cache type to static via model.generation_config.cache_implementation = "static", or using StaticCache directly, breaks with multi-GPU. A minimal sketch of the setup is shown below; it throws the error in the traceback that follows, probably because the cache is not placed on the right device for some layers:
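This sketch is illustrative rather than my exact script: the checkpoint id, prompt, and generation settings are assumptions, but any model sharded across multiple GPUs with device_map="auto" shows the same failure.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"  # illustrative; any checkpoint large enough to be sharded

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # shards the model across both A100s via accelerate
)

# Switching from the default (working) dynamic cache to the static cache triggers the error.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Running generation with this setup throws: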
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:640, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
637 if past_key_value is not None:
638 # sin and cos are specific to RoPE models; cache_position needed for the static cache
639 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 640 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
642 key_states = repeat_kv(key_states, self.num_key_value_groups)
643 value_states = repeat_kv(value_states, self.num_key_value_groups)
File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1083, in StaticCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1080 try:
1081 # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
1082 cache_position.to(device=k_out.device)
-> 1083 k_out.index_copy_(2, cache_position, key_states)
1084 v_out.index_copy_(2, cache_position, value_states)
1085 except NotImplementedError:
1086 # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA_index_copy_)

I can reproduce this with Llama 3 70B on 2x A100. The dynamic cache version works fine, and the static cache also works fine on a single GPU with a smaller model (the same model, but quantized).
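Looking at the quoted cache_utils.py code, line 1082 calls cache_position.to(device=k_out.device) without reassigning the result. Since Tensor.to is out-of-place, I would expect cache_position to stay on its original device, which matches the device mismatch reported for "argument index" in the error. A tiny self-contained illustration of that behavior (nothing here is specific to transformers):

```python
import torch

idx = torch.arange(4)           # created on CPU
idx.to(device="cuda:0")         # out-of-place: the moved copy is discarded
print(idx.device)               # still cpu

idx = idx.to(device="cuda:0")   # reassignment is needed to actually move the tensor
print(idx.device)               # cuda:0
```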
Expected behavior
Static cache generation should work in a multi-GPU setup, just as dynamic cache generation does.