fix: bugs in lora model save
dkqkxx committed May 20, 2023
1 parent 6195a81 commit 0f4a1d5
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions uniform_finetune.py
@@ -31,6 +31,7 @@

import argparse
from utils.device import get_device_map
from utils.save import SavePeftModelCallback

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -400,6 +401,7 @@ def generate_and_tokenize_prompt(data_point):
            ddp_find_unused_parameters=False if ddp else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True),
        callbacks=[SavePeftModelCallback],
    )
    model.config.use_cache = False

9 changes: 9 additions & 0 deletions utils/save.py
@@ -0,0 +1,9 @@
import os
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_folder)
        return control
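
For context, a minimal sketch of how a checkpoint written by this callback could be loaded back after training, assuming the LoRA wrapper comes from the peft library (as the callback name suggests); the base-model name and checkpoint path below are placeholders, not part of this commit:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the same base model that was finetuned (placeholder path).
base_model = AutoModelForCausalLM.from_pretrained("path/to/base-model")

# The callback wrote only the adapter files (adapter_config.json plus the
# adapter weight file) into the checkpoint folder via save_pretrained, so the
# adapter is attached on top of the base model rather than loaded as a full
# model checkpoint (placeholder checkpoint path).
model = PeftModel.from_pretrained(base_model, "output_dir/checkpoint-400")

Saving through save_pretrained at each on_save keeps checkpoints small, since only the LoRA adapter weights are written rather than the full base model.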
