
Using accelerate launch with FSDP causes weights saved from the 2nd time onwards to be incomplete #31034

@aliencaocao

Description

System Info

  • transformers version: 4.41.1
  • Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.2
  • Accelerate version: 0.30.1
  • Accelerate config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: FSDP
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 5
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - fsdp_config: {'fsdp_auto_wrap_policy': 'SIZE_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_min_num_params': 100000000, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 'SHARD_GRAD_OP', 'fsdp_state_dict_type': 'FULL_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: FSDP on 5 GPUs in 1 node

Who can help?

@pacman100 @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

1. Train a model with FSDP as configured via `accelerate config`.
2. The first save produces a correct checkpoint (it doesn't matter when the weight is saved: mid-epoch, first step, end of 100 epochs, etc., as long as it's the first time).
3. From the 2nd save onwards, the weights are mysteriously ~100MB smaller, with all the keys present BUT no weights in some of them and wrong shapes in others. This causes an error when loading:
[screenshot: error raised when loading the corrupted checkpoint]
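
For context, a minimal sketch of the kind of Trainer-based script that triggers this (the model id, dataset, and save schedule here are placeholders, not my exact script):

```python
# train.py -- minimal repro sketch; launched with:
#   accelerate launch train.py
# using the FSDP config from `accelerate config` shown in System Info above.
from transformers import AutoModel, Trainer, TrainingArguments

model = AutoModel.from_pretrained("google/siglip-base-patch16-224")  # placeholder model

args = TrainingArguments(
    output_dir="out",
    save_strategy="steps",
    save_steps=100,  # any save schedule reproduces it; only the FIRST save is intact
    num_train_epochs=1,
    per_device_train_batch_size=8,
    bf16=True,
)

trainer = Trainer(model=model, args=args, train_dataset=my_dataset)  # my_dataset: placeholder
trainer.train()
# out/checkpoint-100 loads fine; out/checkpoint-200 and later are ~100MB
# smaller and fail to load with the error in the screenshot above.
```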

I have tested both SigLIP and OWLv2 so far, and both have the same issue; other models may as well. It happens with both safetensors and pytorch_model.bin, and pytorch_model_fsdp.bin is missing the weights too.
I have set `fsdp_state_dict_type` to FULL_STATE_DICT.
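
A quick way to verify the corruption is to diff tensor shapes between the first (good) checkpoint and a later (bad) one. A sketch, assuming safetensors checkpoints and placeholder paths:

```python
# compare_checkpoints.py -- sketch; checkpoint paths are hypothetical.
import os
from safetensors import safe_open

def tensor_shapes(path):
    """Read key -> shape from a safetensors file without loading the full tensors."""
    with safe_open(path, framework="pt") as f:
        return {k: tuple(f.get_slice(k).get_shape()) for k in f.keys()}

good_path = "out/checkpoint-100/model.safetensors"  # first save: loads fine
bad_path = "out/checkpoint-200/model.safetensors"   # later save: corrupted

print(os.path.getsize(good_path) - os.path.getsize(bad_path), "bytes missing")  # ~100MB

good, bad = tensor_shapes(good_path), tensor_shapes(bad_path)
for key, shape in good.items():
    if key not in bad:
        print("missing key:", key)
    elif bad[key] != shape:
        print(f"wrong shape for {key}: {bad[key]} (expected {shape})")
```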

Expected behavior

No issue saving: checkpoints after the first should be just as complete and loadable as the first one.
