System Info
- transformers version: 4.41.1
- Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.1
- Safetensors version: 0.4.2
- Accelerate version: 0.30.1
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: FSDP
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 5
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - fsdp_config: {'fsdp_auto_wrap_policy': 'SIZE_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_min_num_params': 100000000, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 'SHARD_GRAD_OP', 'fsdp_state_dict_type': 'FULL_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
  - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: FSDP on 5 GPUs in 1 node
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Train a model with FSDP as configured via accelerate config (a minimal sketch of the training setup is below).
The first time the weights are saved, the checkpoint is fine; it does not matter when that first save happens (mid-epoch, at the first step, at the end of 100 epochs, etc.), as long as it is the first one.
From the second save onwards, the checkpoint is mysteriously ~100 MB smaller: all the keys are still present, but some of them contain no weights and others have the wrong shape, which causes an error when loading.
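This is a sketch, not my actual training script: the dummy dataset and the squared-pooler loss are stand-ins just to drive the model through a few steps and trigger two checkpoint saves. Launched with accelerate launch repro.py under the FSDP config above.

```python
# repro.py - run with: accelerate launch repro.py
import torch
from torch.utils.data import Dataset
from transformers import SiglipVisionModel, Trainer, TrainingArguments


class DummyImages(Dataset):
    """Random pixel tensors shaped like SigLIP vision inputs (stand-in data)."""

    def __len__(self):
        return 256

    def __getitem__(self, idx):
        return {"pixel_values": torch.randn(3, 224, 224)}


class PoolerTrainer(Trainer):
    # Hypothetical loss, only there to make the model trainable end to end.
    def compute_loss(self, model, inputs, return_outputs=False):
        out = model(**inputs)
        loss = out.pooler_output.pow(2).mean()
        return (loss, out) if return_outputs else loss


model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,
    bf16=True,
    save_strategy="steps",
    save_steps=5,   # checkpoint-5 loads fine; checkpoint-10 onwards is corrupted
    max_steps=15,
)
PoolerTrainer(model=model, args=args, train_dataset=DummyImages()).train()
```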

I have so far tested both SigLIP and OWLv2, and both have the same issue; other models may be affected as well. It happens with both safetensors and pytorch .bin checkpoints, and pytorch_model_fsdp.bin is also missing the weights.
I have set fsdp_state_dict_type to FULL_STATE_DICT.
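A quick way to see the corruption is to diff tensor shapes between the first (good) checkpoint and a later (broken) one. Hypothetical helper, assuming the default Trainer layout (out/checkpoint-*/model.safetensors) from the sketch above:

```python
# compare_ckpts.py - diff tensor shapes between two Trainer checkpoints.
from safetensors import safe_open


def shapes(path):
    # get_slice reads only metadata, so the tensors are not loaded into memory
    with safe_open(path, framework="pt") as f:
        return {k: tuple(f.get_slice(k).get_shape()) for k in f.keys()}


good = shapes("out/checkpoint-5/model.safetensors")
bad = shapes("out/checkpoint-10/model.safetensors")

for key in good:
    if key not in bad:
        print(f"missing: {key}")
    elif good[key] != bad[key]:
        print(f"shape mismatch: {key} {good[key]} -> {bad[key]}")
```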
Expected behavior
Checkpoints should save correctly on every save, not only the first one.