diff --git a/uniform_finetune.py b/uniform_finetune.py
index 1be8222..079ef9e 100644
--- a/uniform_finetune.py
+++ b/uniform_finetune.py
@@ -31,6 +31,7 @@
 import argparse
 
 from utils.device import get_device_map
+from utils.save import SavePeftModelCallback
 
 device_map = "auto"
 world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -400,6 +401,7 @@ def generate_and_tokenize_prompt(data_point):
        ddp_find_unused_parameters=False if ddp else None,
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True),
+    callbacks=[SavePeftModelCallback],
 )
 
 model.config.use_cache = False
diff --git a/utils/save.py b/utils/save.py
new file mode 100644
index 0000000..b5928c1
--- /dev/null
+++ b/utils/save.py
@@ -0,0 +1,9 @@
+import os
+from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+class SavePeftModelCallback(TrainerCallback):
+    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+        kwargs["model"].save_pretrained(checkpoint_folder)
+        return control
\ No newline at end of file
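
Note: a checkpoint written by SavePeftModelCallback contains only the PEFT adapter weights
(adapter_config.json and adapter_model.bin), not the full base model, so it is restored through
peft rather than a plain from_pretrained. A minimal sketch of reloading such a checkpoint follows;
the base-model name and checkpoint path are placeholders, not values taken from this patch:

    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Placeholder identifiers; substitute the base model and output_dir actually used for fine-tuning.
    base_model = AutoModelForCausalLM.from_pretrained("your-base-model")
    # Attach the adapter weights saved by the callback at a given global step.
    model = PeftModel.from_pretrained(base_model, "saved_models/checkpoint-200")
    model.eval()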