Skip to content

Commit

Permalink
Use fsdp api for save (lm-sys#2390)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Sep 10, 2023
1 parent 6af0a7c commit 9b3147e
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions fastchat/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ def rank0_print(*args):
print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Copy the model's state dict to CPU and write it to ``output_dir``.

    The state dict is requested from ``trainer.model`` on every caller,
    regardless of ``trainer.args.should_save``; only callers for which
    ``should_save`` is true go on to move the tensors to CPU and save.

    Args:
        trainer: Trainer whose ``model`` is saved and whose ``args.should_save``
            flag gates the write.
        output_dir: Destination passed through to ``trainer._save``.
    """
    # Fetched before the should_save check on purpose of ordering: every
    # caller runs this line. NOTE(review): presumably required because the
    # gather is collective under sharded training -- confirm.
    full_state = trainer.model.state_dict()
    if not trainer.args.should_save:
        return

    on_cpu = {}
    for name, tensor in full_state.items():
        on_cpu[name] = tensor.cpu()
    # Drop the reference to the original mapping before saving.
    del full_state
    trainer._save(output_dir, state_dict=on_cpu)  # noqa
def trainer_save_model_safe(trainer: transformers.Trainer):
    """Call ``trainer.save_model()`` inside an FSDP full-state-dict context.

    The context is ``FSDP.state_dict_type`` applied to ``trainer.model`` with
    ``StateDictType.FULL_STATE_DICT`` and a ``FullStateDictConfig`` built with
    ``offload_to_cpu=True`` and ``rank0_only=True``.

    Args:
        trainer: Trainer whose ``model`` is configured and whose
            ``save_model`` method performs the actual write.
    """
    # Imported locally, as in the original, so torch's FSDP package is only
    # loaded when this function is actually called.
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel,
        StateDictType,
    )

    # NOTE(review): per the flag names this should yield an unsharded state
    # dict on CPU, on rank 0 only -- confirm against the torch FSDP docs.
    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    full_state_ctx = FullyShardedDataParallel.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, full_state_cfg
    )
    with full_state_ctx:
        trainer.save_model()


def preprocess(
Expand Down Expand Up @@ -279,9 +281,11 @@ def train():
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()

# Save model
model.config.use_cache = True
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
trainer_save_model_safe(trainer)


if __name__ == "__main__":
Expand Down

0 comments on commit 9b3147e

Please sign in to comment.