Skip to content

Commit

Permalink
eval before merge and reinit
Browse files: browse the repository at this point in the history
  • Loading branch information
Guitaricet committed Aug 8, 2023
1 parent 04b1d6f commit 3af9c86
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions torchrun_main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -795,6 +795,22 @@ def main(args):
can_reset = args.resume_from is not None \
or (args.relora is not None and local_step * args.gradient_accumulation > args.relora)

# ##############################
# EVALUATION
if update_step % args.eval_every == 0:
logger.info(f"Performing evaluation at step {update_step}")
total_loss, evaluated_on_tokens = evaluate_model(model, eval_loader, device)

if global_rank == 0:
wandb.log({
"final_eval_loss": total_loss,
"final_eval_tokens": evaluated_on_tokens,
},
step=global_step,
)
logger.info(f"Eval loss at step {update_step}: {total_loss}")
# ##############################

# ##############################
# MERGE AND REINIT
if can_reset and update_step % args.relora == 1:
Expand Down Expand Up @@ -824,22 +840,6 @@ def main(args):
if can_reset and update_step % args.relora == 2:
logger.info(f"First step after lora reset lr is {optimizer.param_groups[0]['lr']}")

# ##############################
# EVALUATION
if update_step % args.eval_every == 0:
logger.info(f"Performing evaluation at step {update_step}")
total_loss, evaluated_on_tokens = evaluate_model(model, eval_loader, device)

if global_rank == 0:
wandb.log({
"final_eval_loss": total_loss,
"final_eval_tokens": evaluated_on_tokens,
},
step=global_step,
)
logger.info(f"Eval loss at step {update_step}: {total_loss}")
# ##############################

lr = optimizer.param_groups[0]["lr"]
tokens_in_update = tokens_seen - tokens_seen_before
tokens_seen_before = tokens_seen
Expand Down

0 comments on commit 3af9c86

Please sign in to comment.