@@ -717,6 +717,7 @@ def main(
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
+    writer = SummaryWriter()

     def finetune_unet(batch, train_encoder=False):
         nonlocal use_offset_noise
@@ -836,7 +837,6 @@ def finetune_unet(batch, train_encoder=False):

         return loss, latents

-    writer = SummaryWriter()
     for epoch in range(first_epoch, num_train_epochs):
         train_loss = 0.0

@@ -961,12 +961,11 @@ def finetune_unet(batch, train_encoder=False):

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             accelerator.log({"training_loss": loss.detach().item()}, step=step)
-            writer.add_scalar('Loss/train', loss.detach().item(), step)
+            writer.add_scalar('Loss/train', loss.detach().item(), global_step)
             progress_bar.set_postfix(**logs)

             if global_step >= max_train_steps:
                 break
-    writer.close()

     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
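For reference, the TensorBoard logging pattern this change converges on looks roughly like the minimal sketch below: the SummaryWriter is created once before the training loop (the line moved to 720), scalars are logged against global_step so the curve stays monotonic across epochs instead of restarting with the per-epoch dataloader step, and the writer is closed once after all training. The dummy loop and loss are stand-ins for the script's Accelerate setup, and the final writer.close() is an assumption; this hunk only removes the in-loop close() and does not show whether a close is re-added elsewhere.

from torch.utils.tensorboard import SummaryWriter

# Create the writer once, before the training loop (not inside it).
writer = SummaryWriter()

global_step = 0
max_train_steps = 10

for epoch in range(2):
    for step, batch in enumerate(range(5)):   # stand-in for train_dataloader
        loss = 1.0 / (global_step + 1)        # stand-in for the computed loss
        global_step += 1

        # Log against global_step: the per-epoch `step` restarts at 0 each epoch,
        # so later epochs would overwrite earlier points on the same x-axis.
        writer.add_scalar('Loss/train', loss, global_step)

        if global_step >= max_train_steps:
            break

# Assumed cleanup: flush and close once after training finishes,
# rather than inside the loop where it would close after the first epoch.
writer.close()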