diff --git a/torchrun_main.py b/torchrun_main.py index f33746c..3ec73cc 100644 --- a/torchrun_main.py +++ b/torchrun_main.py @@ -795,6 +795,22 @@ def main(args): can_reset = args.resume_from is not None \ or (args.relora is not None and local_step * args.gradient_accumulation > args.relora) + # ############################## + # EVALUATION + if update_step % args.eval_every == 0: + logger.info(f"Performing evaluation at step {update_step}") + total_loss, evaluated_on_tokens = evaluate_model(model, eval_loader, device) + + if global_rank == 0: + wandb.log({ + "final_eval_loss": total_loss, + "final_eval_tokens": evaluated_on_tokens, + }, + step=global_step, + ) + logger.info(f"Eval loss at step {update_step}: {total_loss}") + # ############################## + # ############################## # MERGE AND REINIT if can_reset and update_step % args.relora == 1: @@ -824,22 +840,6 @@ def main(args): if can_reset and update_step % args.relora == 2: logger.info(f"First step after lora reset lr is {optimizer.param_groups[0]['lr']}") - # ############################## - # EVALUATION - if update_step % args.eval_every == 0: - logger.info(f"Performing evaluation at step {update_step}") - total_loss, evaluated_on_tokens = evaluate_model(model, eval_loader, device) - - if global_rank == 0: - wandb.log({ - "final_eval_loss": total_loss, - "final_eval_tokens": evaluated_on_tokens, - }, - step=global_step, - ) - logger.info(f"Eval loss at step {update_step}: {total_loss}") - # ############################## - lr = optimizer.param_groups[0]["lr"] tokens_in_update = tokens_seen - tokens_seen_before tokens_seen_before = tokens_seen