FSDP Model loading with accelerate results in crash (OOM) #26157

Closed · jphme opened this issue Sep 13, 2023 · 2 comments
jphme commented Sep 13, 2023

System Info

  • transformers version: 4.34.0.dev0
  • Platform: Linux-5.4.0-156-generic-x86_64-with-glibc2.35
  • Python version: 3.9.18
  • Huggingface_hub version: 0.17.1
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0.dev0
  • Accelerate config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: y
  • Using distributed or parallel set-up in script?: FSDP

Who can help?

@pacman100

When trying to start a full fine-tune of Llama 7b on a 4x V100 instance with accelerate (using this config without bf16; I also tried other variations, e.g. with fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer), CPU RAM fills up until the process is terminated.

I thought that #25107 should have solved this, but whatever I do, I can't get it to work. Could the Volta architecture be a reason for this?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

accelerate config:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
use_cpu: false

seq_len 2048, Llama-2-7b; happens with different datasets; 4x V100, 173 GB RAM
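
(The actual training script is not shown in the issue. For context, a hypothetical minimal sketch of the ordering that typically triggers this kind of OOM with an FSDP config like the one above: the model is loaded before TrainingArguments is created, so the distributed process group is not yet initialized and every rank loads the full checkpoint into CPU RAM. The model id and output directory below are assumed placeholders, not taken from the report.)

# Hypothetical reproduction sketch, not the reporter's actual script.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# The model is loaded BEFORE TrainingArguments, so torch.distributed is not yet
# initialized and each of the 4 ranks materializes the full checkpoint in CPU RAM.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model id

training_arguments = TrainingArguments(output_dir="llama-7b-fsdp")  # assumed output dir

trainer = Trainer(model=model, args=training_arguments)  # dataset wiring omitted
# trainer.train() would follow once a dataset is attached.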

Expected behavior

The model loads and fine-tuning works with FSDP.

pacman100 commented Sep 14, 2023

Hello, can you share the training code? Please make sure that the torch distributed process group is already initialized before loading the pretrained model. When using Trainer, make sure the TrainingArguments object is created before loading the pretrained model, as creating it initializes the torch distributed process group.

from transformers import AutoModelForCausalLM, TrainingArguments

# Create TrainingArguments first; this initializes the torch distributed process group.
training_arguments = TrainingArguments(
    ...
)

# Only then load the pretrained model.
model = AutoModelForCausalLM.from_pretrained()
...

This is because we want only the main process to have the pretrained model loaded, while all other processes have empty weights. For this to happen, the process group needs to be initialized via torch.distributed.init_process_group, which happens when a TrainingArguments object is created. See the check below, which is required for RAM-efficient FSDP loading:

def is_fsdp_enabled():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
    )
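
Putting this together, a minimal sketch of the recommended ordering when launching with accelerate launch and an FSDP config like the one above; the model id, output directory, and the explicit sanity checks are illustrative assumptions, not required code:

import os

import torch.distributed as dist
from transformers import AutoModelForCausalLM, TrainingArguments

# 1) Create TrainingArguments first; under accelerate launch this initializes
#    the torch distributed process group.
training_arguments = TrainingArguments(output_dir="llama-7b-fsdp")  # assumed output dir

# 2) Optional sanity checks mirroring is_fsdp_enabled(): both conditions should
#    hold before from_pretrained() so RAM-efficient FSDP loading kicks in.
assert dist.is_available() and dist.is_initialized()
assert os.environ.get("ACCELERATE_USE_FSDP", "false").lower() in ("1", "true")

# 3) Only now load the model: the main rank loads the real weights and the
#    other ranks are created with empty weights, then synced by FSDP.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model id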

jphme commented Sep 14, 2023

Many thanks, this solved my issue (I also had a small configuration issue that overwrote some config flags). FSDP is now finally working!

jphme closed this as completed on Sep 14, 2023