
Commit 6c44dde

Merge pull request #1 from xfanac/tensorboard
Fix bug for tensorboard.
2 parents: 95a58b8 + 12b2ceb

File tree

1 file changed: 2 additions, 3 deletions


train.py

Lines changed: 2 additions & 3 deletions
@@ -717,6 +717,7 @@ def main(
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
+    writer = SummaryWriter()
 
     def finetune_unet(batch, train_encoder=False):
         nonlocal use_offset_noise
@@ -836,7 +837,6 @@ def finetune_unet(batch, train_encoder=False):
 
         return loss, latents
 
-    writer = SummaryWriter()
     for epoch in range(first_epoch, num_train_epochs):
         train_loss = 0.0
 
@@ -961,12 +961,11 @@ def finetune_unet(batch, train_encoder=False):
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            accelerator.log({"training_loss": loss.detach().item()}, step=step)
-            writer.add_scalar('Loss/train', loss.detach().item(), step)
+            writer.add_scalar('Loss/train', loss.detach().item(), global_step)
             progress_bar.set_postfix(**logs)
 
             if global_step >= max_train_steps:
                 break
-            writer.close()
 
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
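
For context, the change moves toward the usual torch.utils.tensorboard pattern: create the SummaryWriter once before the training loop, log scalars against the monotonically increasing global step rather than the per-epoch step index, and close the writer only after training finishes. The sketch below is illustrative only and is not the project's code; num_train_epochs, train_dataloader, and train_step are assumed stand-ins for the script's real objects.

# Minimal sketch of the intended TensorBoard logging flow (assumed names).
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()            # create once, before training starts
global_step = 0

for epoch in range(num_train_epochs):                   # num_train_epochs: assumed
    for step, batch in enumerate(train_dataloader):     # train_dataloader: assumed
        loss = train_step(batch)                        # placeholder for the real training step
        # Log against global_step so the curve is continuous across epochs,
        # instead of restarting at 0 with each epoch's local step counter.
        writer.add_scalar('Loss/train', loss.item(), global_step)
        global_step += 1

writer.close()                      # close once, after the loop finishes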
