diff --git a/.pylintrc b/.pylintrc
index e94869511..d6f8a5d6c 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -333,7 +333,7 @@ indent-string='    '
 max-line-length=100
 
 # Maximum number of lines in a module.
-max-module-lines=1100
+max-module-lines=1200
 
 # Allow the body of a class to be on the same line as the declaration if body
 # contains single statement.
diff --git a/README.md b/README.md
index 26c1da347..7fd8fd5d7 100644
--- a/README.md
+++ b/README.md
@@ -278,6 +278,11 @@ You can set `output_dir` to a local directory and set `save_model_dir` to COS to
 
 In order to achieve the fastest train time, set `save_strategy="no"`, as saving no checkpoints except for the final model will remove intermediate write operations all together.
 
+#### Resuming tuning from checkpoints
+If the output directory already contains checkpoints, tuning automatically resumes from the latest checkpoint in the directory specified by the `output_dir` flag. To start tuning from scratch and ignore existing checkpoints, set the `resume_from_checkpoint` flag to `False`.
+
+You can also use the `resume_from_checkpoint` flag to resume tuning from a specific checkpoint by providing the full path to the desired checkpoint as a string. This flag is passed as an argument to the [trainer.train()](https://github.com/huggingface/transformers/blob/db70426854fe7850f2c5834d633aff637f14772e/src/transformers/trainer.py#L1901) function of the SFTTrainer.
+
 ## Tuning Techniques:
 
 ### LoRA Tuning Example
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 251f6d6b9..2d55b7de4 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -80,6 +80,214 @@ PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
 
 
+def test_resume_training_from_checkpoint():
+    """
+    Test tuning resumes from the latest checkpoint, creating new checkpoints, and that the
+    checkpoints created before resuming tuning are not affected.
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get trainer state of latest checkpoint
+        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
+        assert init_trainer_state is not None
+
+        # Resume training with higher epoch and same output dir
+        train_args.num_train_epochs += 5
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get trainer state of latest checkpoint
+        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
+        assert final_trainer_state is not None
+
+        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
+        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]
+
+        # Check that the loss of the first epoch from the initial tuning run is
+        # not overwritten after resuming tuning
+        assert len(init_trainer_state["log_history"]) > 0
+
+        init_log_history = init_trainer_state["log_history"][0]
+        assert init_log_history["epoch"] == 1
+
+        final_log_history = final_trainer_state["log_history"][0]
+        assert final_log_history["epoch"] == 1
+
+        assert init_log_history["loss"] == final_log_history["loss"]
+
+
+def test_resume_training_from_checkpoint_with_flag_true():
+    """
+    Test tuning resumes from the latest checkpoint when the flag is true,
+    creating new checkpoints, and that the checkpoints created before
+    resuming tuning are not affected.
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.resume_from_checkpoint = "True"
+
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get trainer state of latest checkpoint
+        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
+        assert init_trainer_state is not None
+
+        # Get training logs
+        init_training_logs = _get_training_logs_by_epoch(tempdir)
+
+        # Resume training with higher epoch and same output dir
+        train_args.num_train_epochs += 5
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get trainer state of latest checkpoint
+        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
+        assert final_trainer_state is not None
+
+        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
+        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]
+
+        final_training_logs = _get_training_logs_by_epoch(tempdir)
+
+        assert (
+            init_training_logs[0]["data"]["timestamp"]
+            == final_training_logs[0]["data"]["timestamp"]
+        )
+
+
+def test_resume_training_from_checkpoint_with_flag_false():
+    """
+    Test that tuning starts from scratch when resume_from_checkpoint is set to False.
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.resume_from_checkpoint = "False"
+
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get trainer state of latest checkpoint
+        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
+        assert init_trainer_state is not None
+
+        # Get training log entry for epoch 1
+        init_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
+        assert len(init_training_logs) == 1
+
+        # Train again with higher epoch and same output dir
+        train_args.num_train_epochs += 5
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
+        _validate_training(tempdir)
+
+        # Get training log entry for epoch 1
+        final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
+        assert len(final_training_logs) == 2
+
+
+def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
+    """
+    Test resuming tuning from a specified checkpoint path for LoRA tuning.
+    """
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        lora_config = copy.deepcopy(PEFT_LORA_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
+        _validate_training(tempdir)
+
+        # Get trainer state and checkpoint path of the second-to-last checkpoint
+        init_trainer_state, checkpoint_path = _get_latest_checkpoint_trainer_state(
+            tempdir, checkpoint_index=-2
+        )
+        assert init_trainer_state is not None
+
+        # Resume training with higher epoch and same output dir
+        train_args.num_train_epochs += 5
+        train_args.resume_from_checkpoint = checkpoint_path
+        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
+        _validate_training(tempdir)
+
+        # Get total_flos from the trainer state at checkpoint_path and check it is unchanged
+        final_trainer_state = None
+        trainer_state_file = os.path.join(checkpoint_path, "trainer_state.json")
+        with open(trainer_state_file, "r", encoding="utf-8") as f:
+            final_trainer_state = json.load(f)
+
+        assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]
+
+
+def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int = -1):
+    """
+    Get the trainer state from the latest or specified checkpoint directory.
+    The trainer state is returned along with the path to the checkpoint.
+
+    Args:
+        dir_path (str): The directory path where checkpoint folders are located.
+        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
+                                          based on the checkpoint number. The default
+                                          is -1, which returns the latest checkpoint.
+
+    Returns:
+        trainer_state: The trainer state loaded from `trainer_state.json` in the
+                       checkpoint directory.
+        last_checkpoint: The path to the checkpoint directory.
+    """
+    trainer_state = None
+    last_checkpoint = None
+    checkpoints = [
+        os.path.join(dir_path, d)
+        for d in os.listdir(dir_path)
+        if d.startswith("checkpoint")
+    ]
+    if checkpoints:
+        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
+            checkpoint_index
+        ]
+        trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
+        with open(trainer_state_file, "r", encoding="utf-8") as f:
+            trainer_state = json.load(f)
+    return trainer_state, last_checkpoint
+
+
+def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
+    """
+    Load the training_logs.jsonl file and optionally filter it.
+    If an epoch number is specified, the function filters the logs
+    and returns only the entries corresponding to the specified epoch.
+
+    Args:
+        dir_path (str): The directory path where the `training_logs.jsonl` file is located.
+        epoch (int, optional): The epoch number to filter logs by. If not specified,
+                               all logs are returned.
+
+    Returns:
+        list: A list containing the training logs. If `epoch` is specified,
+              only logs from the specified epoch are returned; otherwise, all logs are returned.
+    """
+    data_list = []
+    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
+        for line in file:
+            json_data = json.loads(line)
+            data_list.append(json_data)
+
+    if epoch is not None:
+        mod_data_list = []
+        for value in data_list:
+            if value["data"]["epoch"] == epoch:
+                mod_data_list.append(value)
+        return mod_data_list
+    return data_list
+
+
 def test_run_train_requires_output_dir():
     """Check fails when output dir not provided."""
     updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 2ab8f7de0..da8fa5172 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -35,6 +35,7 @@
     LlamaTokenizerFast,
     TrainerCallback,
 )
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_accelerate_available
 from trl import SFTConfig, SFTTrainer
 import transformers
@@ -215,7 +216,7 @@ def train(
         ),
     )
 
-    # add special tokens only when a custom tokenizer is not passed
+    # Add special tokens only when a custom tokenizer is not passed
     if not model_args.tokenizer_name_or_path:
         # TODO: understand if we need to hardcode these here or just use defaults in model
         if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
@@ -366,7 +367,24 @@ def train(
         for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
             trainer.add_callback(x)
 
-    trainer.train()
+    resume_from_checkpoint = None
+    # Resume from the last checkpoint in output_dir (if any) when the flag is
+    # not passed (None) or is explicitly set to "true"
+    if (
+        training_args.resume_from_checkpoint is None
+        or training_args.resume_from_checkpoint.lower() == "true"
+    ):
+        resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
+    else:
+        # `training_args.resume_from_checkpoint` is parsed as a string, so it is
+        # either "false" (start from scratch) or a path to a specific checkpoint
+        resume_from_checkpoint = (
+            training_args.resume_from_checkpoint
+            if training_args.resume_from_checkpoint.lower() != "false"
+            else False
+        )
+
+    trainer.train(resume_from_checkpoint)
 
     return trainer
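
Note for reviewers: the sketch below is illustrative only and is not part of the patch. It restates the flag-resolution logic added to tuning/sft_trainer.py as a standalone helper so the three cases are easy to read in isolation. The helper name resolve_resume_from_checkpoint is hypothetical; get_last_checkpoint is the real transformers.trainer_utils utility imported above.

# Illustrative sketch, not part of the patch: how the string-valued
# `resume_from_checkpoint` training argument maps to the value handed to
# trainer.train(). The helper name is hypothetical.
from typing import Optional, Union

from transformers.trainer_utils import get_last_checkpoint


def resolve_resume_from_checkpoint(
    flag: Optional[str], output_dir: str
) -> Union[str, bool, None]:
    """Map the string flag to the value expected by trainer.train()."""
    if flag is None or flag.lower() == "true":
        # Flag not passed, or explicitly "true": resume from the latest checkpoint
        # in output_dir; get_last_checkpoint returns None if there is none yet,
        # in which case tuning starts from scratch.
        return get_last_checkpoint(output_dir)
    if flag.lower() == "false":
        # Explicitly disabled: ignore existing checkpoints and start from scratch.
        return False
    # Any other string is treated as the full path to a specific checkpoint.
    return flag

Keeping the argument as a string rather than a bool lets a single flag carry all three cases ("true", "false", or an explicit checkpoint path), which is also why the tests above set resume_from_checkpoint to the strings "True" and "False".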