
[FSDP][torch.compile] accelerator.unwrap_model and trainer._save work incorrectly when FSDP + torch.compile #37519

@efsotr


System Info

transformers 4.51.3
accelerate 1.6.0

Who can help?

@zach-huggingface @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

To use torch.compile, you need to either uninstall the kernels library or set the environment variable DISABLE_KERNEL_MAPPING to 1.

train.py

from typing import cast

import torch
from transformers import HfArgumentParser, Trainer, TrainingArguments, LlamaForCausalLM, LlamaConfig

parser = HfArgumentParser(TrainingArguments)
training_args = cast(TrainingArguments, parser.parse_args_into_dataclasses()[0])
print(training_args, flush=True)

config = LlamaConfig(
    vocab_size=128, 
    hidden_size=128, 
    intermediate_size=128*2,
    num_hidden_layers=2
)
model = LlamaForCausalLM(config).cuda().bfloat16()

# 16 random sequences of 128 token ids (vocab_size=128), used for both training and evaluation
train_dataset = [{"input_ids": torch.randint(0, 128, (128,)),
                  "labels": torch.randint(0, 128, (128,))} for _ in range(16)]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)

trainer.train()
trainer.save_state()
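The title points at accelerator.unwrap_model. Because the saved keys shown further down still start with `_orig_mod.`, the torch.compile wrapper evidently was not removed before the weights were written. For context, here is a minimal pure-Python sketch of what a complete unwrap has to undo when a model is both compiled and sharded, whatever the nesting order. The helper `unwrap` and the stand-in classes are hypothetical and this is not accelerate's implementation; the sketch only assumes that torch.compile's wrapper stores the original module in `_orig_mod` and that FSDP/DDP wrappers expose it as `.module`.

```python
from typing import Any


def unwrap(model: Any) -> Any:
    """Peel wrapper layers off `model` until none remain (hypothetical helper).

    Assumptions: torch.compile's wrapper keeps the original module in
    `_orig_mod`, and FSDP/DDP-style wrappers expose it as `.module`.
    Caveat: a model that owns a submodule literally named `module` would
    also be peeled, so a real version should use isinstance checks instead.
    """
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


# Stand-ins for the real wrappers so the sketch runs without torch.
class Model:
    pass


class CompiledWrapper:  # plays the role of torch.compile's wrapper
    def __init__(self, inner: Any) -> None:
        self._orig_mod = inner


class ShardedWrapper:  # plays the role of FSDP / DDP
    def __init__(self, inner: Any) -> None:
        self.module = inner


base = Model()
# Both nesting orders resolve to the same underlying model.
assert unwrap(CompiledWrapper(ShardedWrapper(base))) is base
assert unwrap(ShardedWrapper(CompiledWrapper(base))) is base
```

Checking both attributes on every iteration, rather than unwrapping the compile layer once up front, is what makes the result independent of the order in which the wrappers were applied.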

fsdp.yaml

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
fsdp_config:
    fsdp_sharding_strategy: FULL_SHARD
    fsdp_activation_checkpointing: false
    fsdp_use_orig_params: true
    fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
    fsdp_backward_prefetch_policy: BACKWARD_PRE
    fsdp_offload_params: false
    fsdp_state_dict_type: FULL_STATE_DICT
    fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer,Embedding
mixed_precision: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

launch script

export CUDA_VISIBLE_DEVICES=0,1
export DISABLE_KERNEL_MAPPING=1

OUTPUT_DIR=test_fsdp
mkdir -p $OUTPUT_DIR

OMP_NUM_THREADS=8 accelerate launch --main_process_port 40129 --config_file fsdp.yaml \
    train.py \
    --torch_compile_mode default \
    --do_train \
    --optim adamw_torch_fused \
    --learning_rate 1e-3 \
    --weight_decay 0 \
    --lr_scheduler_type constant_with_warmup \
    --warmup_ratio 0.1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_on_start 0 \
    --eval_strategy epoch \
    --eval_steps 1 \
    --save_strategy epoch \
    --save_only_model 1 \
    --greater_is_better False \
    --logging_strategy steps \
    --logging_steps 1 \
    --include_tokens_per_second \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --seed 0 \
    --report_to none \
    > $OUTPUT_DIR/training.log 2>&1 

Expected behavior

The file test_fsdp/checkpoint-2/config.json should exist.

Running the following script on the saved checkpoint

from safetensors import safe_open

path = "test_fsdp/checkpoint-2/model.safetensors"
with safe_open(path, framework="pt") as f:
    keys = list(f.keys())
    print(keys)
    lm_head = "lm_head.weight"
    if lm_head not in keys:
        # torch.compile's wrapper prefixes every key with "_orig_mod."
        lm_head = "_orig_mod." + lm_head
    print(f.get_tensor(lm_head).shape)

is expected to print:

['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.norm.weight']
torch.Size([128, 128])

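Instead, it prints the following: every key carries the _orig_mod. prefix, and lm_head.weight is a 1-D tensor of 8128 elements rather than a [128, 128] matrix: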

['_orig_mod.lm_head.weight', '_orig_mod.model.embed_tokens.weight', '_orig_mod.model.layers.0.input_layernorm.weight', '_orig_mod.model.layers.0.mlp.down_proj.weight', '_orig_mod.model.layers.0.mlp.gate_proj.weight', '_orig_mod.model.layers.0.mlp.up_proj.weight', '_orig_mod.model.layers.0.post_attention_layernorm.weight', '_orig_mod.model.layers.0.self_attn.k_proj.weight', '_orig_mod.model.layers.0.self_attn.o_proj.weight', '_orig_mod.model.layers.0.self_attn.q_proj.weight', '_orig_mod.model.layers.0.self_attn.v_proj.weight', '_orig_mod.model.layers.1.input_layernorm.weight', '_orig_mod.model.layers.1.mlp.down_proj.weight', '_orig_mod.model.layers.1.mlp.gate_proj.weight', '_orig_mod.model.layers.1.mlp.up_proj.weight', '_orig_mod.model.layers.1.post_attention_layernorm.weight', '_orig_mod.model.layers.1.self_attn.k_proj.weight', '_orig_mod.model.layers.1.self_attn.o_proj.weight', '_orig_mod.model.layers.1.self_attn.q_proj.weight', '_orig_mod.model.layers.1.self_attn.v_proj.weight', '_orig_mod.model.norm.weight']
torch.Size([8128])

If --eval_strategy epoch in the launch script is changed to --eval_strategy no, the keys still carry the _orig_mod. prefix, but lm_head.weight has the correct [128, 128] shape:

['_orig_mod.lm_head.weight', '_orig_mod.model.embed_tokens.weight', '_orig_mod.model.layers.0.input_layernorm.weight', '_orig_mod.model.layers.0.mlp.down_proj.weight', '_orig_mod.model.layers.0.mlp.gate_proj.weight', '_orig_mod.model.layers.0.mlp.up_proj.weight', '_orig_mod.model.layers.0.post_attention_layernorm.weight', '_orig_mod.model.layers.0.self_attn.k_proj.weight', '_orig_mod.model.layers.0.self_attn.o_proj.weight', '_orig_mod.model.layers.0.self_attn.q_proj.weight', '_orig_mod.model.layers.0.self_attn.v_proj.weight', '_orig_mod.model.layers.1.input_layernorm.weight', '_orig_mod.model.layers.1.mlp.down_proj.weight', '_orig_mod.model.layers.1.mlp.gate_proj.weight', '_orig_mod.model.layers.1.mlp.up_proj.weight', '_orig_mod.model.layers.1.post_attention_layernorm.weight', '_orig_mod.model.layers.1.self_attn.k_proj.weight', '_orig_mod.model.layers.1.self_attn.o_proj.weight', '_orig_mod.model.layers.1.self_attn.q_proj.weight', '_orig_mod.model.layers.1.self_attn.v_proj.weight', '_orig_mod.model.norm.weight']
torch.Size([128, 128])
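The two symptoms above (the `_orig_mod.` prefix in both runs, and the 1-D `lm_head.weight` only when evaluation runs every epoch) can be checked mechanically. Below is a minimal sketch with hypothetical helper names that strips the prefix and compares saved shapes against those implied by the LlamaConfig in train.py. It assumes the LlamaConfig defaults the script leaves untouched (32 attention heads, as many key/value heads, no bias terms), which agree with the expected key list above. In the usage example, only the `lm_head.weight` shape comes from the report; the other shapes are filled in with expected values for illustration.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]
COMPILE_PREFIX = "_orig_mod."  # prefix seen on every key in the saved checkpoint


def strip_compile_prefix(shapes: Dict[str, Shape]) -> Dict[str, Shape]:
    """Drop a leading '_orig_mod.' from each key; other keys are kept as-is."""
    return {k.removeprefix(COMPILE_PREFIX): v for k, v in shapes.items()}


def expected_llama_shapes(vocab_size: int, hidden_size: int, intermediate_size: int,
                          num_hidden_layers: int, num_attention_heads: int = 32,
                          num_key_value_heads: Optional[int] = None) -> Dict[str, Shape]:
    """Weight shapes for a bias-free Llama model (assumed LlamaConfig defaults)."""
    kv_heads = num_key_value_heads or num_attention_heads
    head_dim = hidden_size // num_attention_heads
    shapes = {
        "lm_head.weight": (vocab_size, hidden_size),
        "model.embed_tokens.weight": (vocab_size, hidden_size),
        "model.norm.weight": (hidden_size,),
    }
    for i in range(num_hidden_layers):
        p = f"model.layers.{i}."
        shapes.update({
            p + "input_layernorm.weight": (hidden_size,),
            p + "post_attention_layernorm.weight": (hidden_size,),
            p + "self_attn.q_proj.weight": (num_attention_heads * head_dim, hidden_size),
            p + "self_attn.k_proj.weight": (kv_heads * head_dim, hidden_size),
            p + "self_attn.v_proj.weight": (kv_heads * head_dim, hidden_size),
            p + "self_attn.o_proj.weight": (hidden_size, num_attention_heads * head_dim),
            p + "mlp.gate_proj.weight": (intermediate_size, hidden_size),
            p + "mlp.up_proj.weight": (intermediate_size, hidden_size),
            p + "mlp.down_proj.weight": (hidden_size, intermediate_size),
        })
    return shapes


def shape_mismatches(actual: Dict[str, Shape],
                     expected: Dict[str, Shape]) -> Dict[str, Tuple[Optional[Shape], Shape]]:
    """Map each expected key whose saved shape differs (None if missing) to (actual, expected)."""
    return {k: (actual.get(k), s) for k, s in expected.items() if actual.get(k) != s}


def saved_shapes(path: str) -> Dict[str, Shape]:
    """Read tensor shapes from a .safetensors file without loading the tensors."""
    from safetensors import safe_open  # imported lazily; only needed for real files
    with safe_open(path, framework="pt") as f:
        return {k: tuple(f.get_slice(k).get_shape()) for k in f.keys()}


expected = expected_llama_shapes(vocab_size=128, hidden_size=128,
                                 intermediate_size=256, num_hidden_layers=2)
# Illustrative stand-in for the eval_strategy=epoch checkpoint.
saved = {COMPILE_PREFIX + k: s for k, s in expected.items()}
saved[COMPILE_PREFIX + "lm_head.weight"] = (8128,)
print(shape_mismatches(strip_compile_prefix(saved), expected))
# {'lm_head.weight': ((8128,), (128, 128))}
```

On the real checkpoint, `saved_shapes("test_fsdp/checkpoint-2/model.safetensors")` would replace the illustrative `saved` dict. Note that stripping the prefix only renames keys; it cannot repair the wrongly shaped `lm_head.weight`.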
