Commit: fix and rollback
CNChTu authored Jan 14, 2024
1 parent 73dd9d9 commit b748138
Showing 4 changed files with 34 additions and 367 deletions.
diffusion/solver.py (22 changes: 5 additions & 17 deletions)
```diff
@@ -66,14 +66,11 @@ def test(args, model, vocoder, loader_test, saver):
                 gt_spec=data['mel'],
                 infer=False,
                 k_step=args.model.k_step_max,
-                spk_emb=data['spk_emb'],
-                use_vae=(args.vocoder.type == 'hifivaegan')
-                )
+                spk_emb=data['spk_emb'])
             test_loss += loss.item()
 
             # log mel
-            if args.vocoder.type != 'hifivaegan':
-                saver.log_spec(data['name'][0], data['mel'], mel)
+            saver.log_spec(data['name'][0], data['mel'], mel)
 
             # log audio
             path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0])
```
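For context, a minimal sketch of how the validation step reads after this rollback, assuming a PyTorch model whose forward returns a scalar loss when `infer=False`. The wrapper function is hypothetical, and `mel` is assumed to be a predicted spectrogram produced by an inference pass outside this hunk:

```python
import torch

@torch.no_grad()
def validate_batch(args, model, data, saver, mel):
    # Loss pass: after the rollback there is no use_vae kwarg and no
    # vocoder-type guard around spectrogram logging.
    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                 gt_spec=data['mel'],
                 infer=False,
                 k_step=args.model.k_step_max,
                 spk_emb=data['spk_emb'])
    # Mel logging is unconditional again.
    saver.log_spec(data['name'][0], data['mel'], mel)
    return loss.item()
```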
```diff
@@ -101,10 +98,6 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
     params_count = utils.get_network_paras_amount({'model': model})
     saver.log_info('--- model size ---')
     saver.log_info(params_count)
-    if args.vocoder.type == 'hifivaegan':
-        use_vae = True
-    else:
-        use_vae = False
 
     # run
     num_batches = len(loader_train)
```
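Incidentally, the four-line `if`/`else` deleted here computed the same value as the inline expression removed from `test()` in the hunk above; a one-line equivalent would have been (illustrative only, since the commit drops the flag entirely):

```python
use_vae = (args.vocoder.type == 'hifivaegan')
```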
```diff
@@ -134,21 +127,16 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
             if dtype == torch.float32:
                 loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                              aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=args.model.k_step_max,
-                             spk_emb=data['spk_emb'], use_vae=use_vae)
+                             spk_emb=data['spk_emb'])
             else:
                 with autocast(device_type=args.device, dtype=dtype):
                     loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                  aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=args.model.k_step_max,
-                                 spk_emb=data['spk_emb'], use_vae=use_vae)
+                                 spk_emb=data['spk_emb'])
 
             # handle nan loss
             if torch.isnan(loss):
-                # raise ValueError(' [x] nan loss ')
-                # If the loss is NaN, skip this batch and clean up to avoid a memory leak
-                print(' [x] nan loss ')
-                optimizer.zero_grad()
-                del loss
-                continue
+                raise ValueError(' [x] nan loss ')
             else:
                 # backpropagate
                 if dtype == torch.float32:
```
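The last change reverts the NaN-loss handling from "skip the batch" back to "fail fast". A minimal sketch of the two strategies in a generic PyTorch training step (names are illustrative, not the repository's API):

```python
import torch

def train_step(model, batch, optimizer, skip_nan=False):
    """One step; skip_nan=True mimics the rolled-back behavior,
    skip_nan=False the fail-fast behavior this commit restores."""
    optimizer.zero_grad()
    loss = model(batch).mean()
    if torch.isnan(loss):
        if skip_nan:
            print(' [x] nan loss ')  # drop the batch and keep training
            del loss                 # free the graph before continuing
            return None
        raise ValueError(' [x] nan loss ')  # surface the bad batch immediately
    loss.backward()
    optimizer.step()
    return loss.item()
```

Failing fast makes NaNs visible at the cost of aborting long runs; the skipped-batch variant kept training going but could silently mask data or numerical-stability problems.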
