Skip to content

Commit

Permalink
Merge pull request CNChTu#40 from mlbv/Stable
Browse files Browse the repository at this point in the history
Update solver.py
  • Loading branch information
CNChTu authored Aug 4, 2023
2 parents 956b981 + 2a8dd93 commit ee37e3f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion diffusion/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade

# run
num_batches = len(loader_train)
start_epoch = initial_global_step // num_batches
model.train()
saver.log_info('======= start training =======')
scaler = GradScaler()
Expand All @@ -112,7 +113,7 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
dtype = torch.bfloat16
else:
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
for epoch in range(args.train.epochs):
for epoch in range(start_epoch, args.train.epochs):
for batch_idx, data in enumerate(loader_train):
saver.global_step_increment()
optimizer.zero_grad()
Expand Down

0 comments on commit ee37e3f

Please sign in to comment.