
Commit 7bce45b
fix optimizer reset and relora logic
Guitaricet committed Sep 28, 2023
1 parent c9604b3 commit 7bce45b
Showing 1 changed file with 5 additions and 4 deletions.
torchrun_main.py (9 changes: 5 additions & 4 deletions)
@@ -676,7 +676,8 @@ def main(args):
     else:
         raise ValueError(f"Optimizer {args.optimizer} not supported")
 
-    _scheduler_steps = args.num_training_steps - update_step
+    scheduler_start_step = update_step
+    _scheduler_steps = args.num_training_steps - scheduler_start_step
     logger.info(f"Scheduler will run for {_scheduler_steps} update steps")
     scheduler = training_utils.get_scheculer(
         optimizer=optimizer,
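
The first hunk records the resume point before the scheduler is rebuilt, so the new scheduler covers only the steps that remain. A minimal sketch of the arithmetic, with hypothetical numbers (not from the repo):

    # Hypothetical values for a resumed run, for illustration only.
    num_training_steps = 20_000  # total update steps planned for the whole run
    update_step = 5_000          # restored from the checkpoint on resume

    scheduler_start_step = update_step  # remember where this run's scheduler begins
    _scheduler_steps = num_training_steps - scheduler_start_step
    print(_scheduler_steps)  # 15000: the rebuilt scheduler runs only for the remainder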
@@ -871,7 +872,7 @@ def main(args):
             or local_step // args.gradient_accumulation >= args.relora
         )
 
-        if can_reset_relora and update_step % args.relora == 1:
+        if can_reset_relora and (update_step - scheduler_start_step) % args.relora == 1:
             _lora_reset_time = time.time()
             logger.info(f"{args.resume_from=}, {local_step=}, {args.relora=}, thresh: {local_step // args.gradient_accumulation}")
             logger.info(f"Performing lora reset at update step {update_step}. Current lr is {optimizer.param_groups[0]['lr']}")
@@ -892,7 +893,7 @@ def main(args):
             or local_step // args.gradient_accumulation >= args.cycle_length
         )
 
-        if can_reset_optimizer and update_step % args.cycle_length == 1:
+        if can_reset_optimizer and (update_step - scheduler_start_step) % args.cycle_length == 1:
             # scheduler should provide a new warmup after the reset
             logger.info(f"Performing optimizer reset at update step {update_step}. Current lr is {optimizer.param_groups[0]['lr']}")
             n_optimizer_resets += 1
@@ -907,7 +908,7 @@ def main(args):
             )
             # ##############################
 
-        if can_reset_optimizer and update_step % args.cycle_length == 2:
+        if can_reset_optimizer and (update_step - scheduler_start_step) % args.cycle_length == 2:
             logger.info(f"First step after optimizer reset lr is {optimizer.param_groups[0]['lr']}")
 
         lr = optimizer.param_groups[0]["lr"]
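
The remaining hunks all make the same correction: the reset checks previously used the absolute update_step, so after resuming at a step that is not a multiple of the cycle length, LoRA and optimizer resets fired out of phase with the freshly built scheduler's warmup cycles. Subtracting scheduler_start_step measures steps relative to the resume point, which is where the new scheduler's cycles actually begin. A small self-contained sketch with hypothetical numbers (not from the repo) showing the misalignment and the fix:

    # Toy example: reset cycles of 100 update steps, training resumed at step 250.
    cycle_length = 100
    scheduler_start_step = 250  # resume point; the rebuilt scheduler's cycles start here
    steps = range(scheduler_start_step, scheduler_start_step + 301)

    # Old check: absolute step count, ignores where the run resumed.
    old_resets = [s for s in steps if s % cycle_length == 1]
    # Fixed check: step count relative to the resume point.
    new_resets = [s for s in steps if (s - scheduler_start_step) % cycle_length == 1]

    print(old_resets)  # [301, 401, 501]: 51 steps into each new scheduler cycle
    print(new_resets)  # [251, 351, 451]: one step after each cycle boundary

The `== 2` check in the last hunk is offset the same way so its log line still fires exactly one update step after each reset.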
