
add delete old checkpoints
remove --train_ln because we never changed it
Guitaricet committed Sep 2, 2023
1 parent fbe45d5 commit ff2cc53
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 0 additions & 4 deletions peft_pretraining/args_utils.py
@@ -31,10 +31,6 @@ def check_args_torchrun_main(args):
     if args.batch_size is None:
         raise ValueError("batch_size must be specified")
 
-    if not args.train_ln:
-        logger.error("Are you sure? Not training LN is a bad idea.")
-        raise ValueError("Are you sure? Not training LN is a bad idea.")
-
     if args.tags is not None:
         args.tags = args.tags.split(",")
 
14 changes: 14 additions & 0 deletions peft_pretraining/training_utils.py
@@ -402,3 +402,17 @@ def check_lr_and_alert(optimizer, max_lr):
         text=alert_message,
         level=wandb.AlertLevel.WARN,
     )
+
+def delete_old_checkpoints(save_dir, keep):
+    if keep is None:
+        return
+
+    checkpoints = [d for d in os.listdir(save_dir) if d.startswith(f"model_")]
+    if len(checkpoints) <= keep:
+        return
+
+    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))
+    for checkpoint in checkpoints[:-keep]:
+        checkpoint_path = os.path.join(save_dir, checkpoint)
+        logger.info(f"Deleting checkpoint {checkpoint_path}")
+        os.system(f"rm -rf {checkpoint_path}")
12 changes: 11 additions & 1 deletion torchrun_main.py
@@ -88,7 +88,6 @@ def parse_args(args=None):
                         help=("Keep original model parameters even if relora is None. "
                               "Useful for making sure that full-LoRa model is equivalent to model+LoRa."))
 
-    parser.add_argument("--train_ln", default=True, action="store_true")
     parser.add_argument("--optimizer", default="Adam", help="Could be adam (for AdamW) or adam_zero for ZeroRedundancyOptimizer(AdamW)")
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_restarts"])
@@ -114,6 +113,8 @@
                               "You can use M and B suffixes, e.g. 100M or 1B.")
     parser.add_argument("--save_every", type=int, default=10_000)
     parser.add_argument("--save_dir", type=str, default=None)
+    parser.add_argument("--keep_checkpoints", type=int, default=None,
+                        help="Number of checkpoints to keep. By default, keep all checkpoints.")
     parser.add_argument("--tags", type=str, default=None)
     parser.add_argument("--dtype", type=str, default="bfloat16" if torch.cuda.is_bf16_supported() else "float32")
     parser.add_argument("--workers", type=int, default=8)
@@ -606,6 +607,13 @@ def main(args):
     trainable_params = [p for p in model.parameters() if p.requires_grad]
     lora_params = [p for n, p in model.named_parameters() if p.requires_grad and "lora_" in n]
     trainable_params_names = [name for name, p in model.named_parameters() if p.requires_grad]
+    non_trainable_params_names = [name for name, p in model.named_parameters() if not p.requires_grad]
+
+    logger.info("*" * 40)
+    logger.info("Non-trainable parameters:")
+    for name in non_trainable_params_names:
+        logger.info(f"{name:40}")
+    logger.info("*" * 40)
 
     if args.use_peft and len(lora_params) == 0:
         raise ValueError("No LoRA parameters found")
@@ -816,6 +824,8 @@ def main(args):
                 distributed_type=args.distributed_type,
                 save_dir=current_model_directory,
             )
+            if args.keep_checkpoints is not None:
+                training_utils.delete_old_checkpoints(args.save_dir, keep=args.keep_checkpoints)
 
             # ##############################
             # EVALUATION
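
A small worked example of what the new retention policy implies when combined with --save_every (hypothetical step counts; this mirrors the pruning logic added in training_utils.py rather than calling into the repo):

# With save_every=10_000 and keep_checkpoints=3, pruning after 50k steps
# leaves only the three newest checkpoint directories.
save_every = 10_000
keep_checkpoints = 3
saved = [f"model_{step}" for step in range(save_every, 50_001, save_every)]
kept = sorted(saved, key=lambda x: int(x.split("_")[-1]))[-keep_checkpoints:]
print(kept)  # ['model_30000', 'model_40000', 'model_50000']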
