diff --git a/diffusion/solver.py b/diffusion/solver.py
index f222021..9eb44c9 100644
--- a/diffusion/solver.py
+++ b/diffusion/solver.py
@@ -66,14 +66,11 @@ def test(args, model, vocoder, loader_test, saver):
                 gt_spec=data['mel'],
                 infer=False,
                 k_step=args.model.k_step_max,
-                spk_emb=data['spk_emb'],
-                use_vae=(args.vocoder.type == 'hifivaegan')
-            )
+                spk_emb=data['spk_emb'])
             test_loss += loss.item()
 
             # log mel
-            if args.vocoder.type != 'hifivaegan':
-                saver.log_spec(data['name'][0], data['mel'], mel)
+            saver.log_spec(data['name'][0], data['mel'], mel)
 
             # log audio
             path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0])
@@ -101,10 +98,6 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
     params_count = utils.get_network_paras_amount({'model': model})
     saver.log_info('--- model size ---')
     saver.log_info(params_count)
-    if args.vocoder.type == 'hifivaegan':
-        use_vae = True
-    else:
-        use_vae = False
 
     # run
     num_batches = len(loader_train)
@@ -134,21 +127,16 @@ def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loade
             if dtype == torch.float32:
                 loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                              aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=args.model.k_step_max,
-                             spk_emb=data['spk_emb'], use_vae=use_vae)
+                             spk_emb=data['spk_emb'])
             else:
                 with autocast(device_type=args.device, dtype=dtype):
                     loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                  aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=args.model.k_step_max,
-                                 spk_emb=data['spk_emb'], use_vae=use_vae)
+                                 spk_emb=data['spk_emb'])
 
             # handle nan loss
             if torch.isnan(loss):
-                #raise ValueError(' [x] nan loss ')
-                # if loss is nan, skip this batch and clean up to prevent a memory leak
-                print(' [x] nan loss ')
-                optimizer.zero_grad()
-                del loss
-                continue
+                raise ValueError(' [x] nan loss ')
             else:
                 # backpropagate
                 if dtype == torch.float32:
diff --git a/diffusion/unit2mel.py b/diffusion/unit2mel.py
index 8eb5d21..a7d729a 100644
--- a/diffusion/unit2mel.py
+++ b/diffusion/unit2mel.py
@@ -5,14 +5,10 @@
 import numpy as np
 import torch.nn.functional as F
 from torch.nn.utils import weight_norm
-import random
 from .diffusion import GaussianDiffusion
 from .wavenet import WaveNet
-from .convnext import ConvNext
 from .vocoder import Vocoder
 from .naive.naive import Unit2MelNaive
-from .naive_v2.naive_v2 import Unit2MelNaiveV2
-from .naive_v2.naive_v2_diff import NaiveV2Diff
 
 
 class DotDict(dict):
@@ -24,18 +20,6 @@ def __getattr__(*args):
     __delattr__ = dict.__delitem__
 
 
-def get_z(stack_tensor, mean_only=False):
-    # stack_tensor: [B x N x D x 2]
-    # sample z, or mean only
-    m = stack_tensor.transpose(-1, 0)[:1].transpose(-1, 0).squeeze(-1)
-    logs = stack_tensor.transpose(-1, 0)[1:].transpose(-1, 0).squeeze(-1)
-    if mean_only:
-        z = m  # mean only
-    else:
-        z = m + torch.randn_like(m) * torch.exp(logs)  # sample z
-    return z  # [B x N x D]
-
-
 def load_model_vocoder(
         model_path,
         device='cpu',
@@ -62,16 +46,13 @@ def load_model_vocoder(
     return model, vocoder, args
 
 
-def load_model_vocoder_from_combo(combo_model_path, device='cpu', loaded_vocoder=None):
+def load_model_vocoder_from_combo(combo_model_path, device='cpu'):
     read_dict = torch.load(combo_model_path, map_location=torch.device(device))
     # args
     diff_args = DotDict(read_dict["diff_config_dict"])
     naive_args = DotDict(read_dict["naive_config_dict"])
     # vocoder
-    if loaded_vocoder is None:
-        vocoder = Vocoder(diff_args.vocoder.type, diff_args.vocoder.ckpt, device=device)
-    else:
-        vocoder = loaded_vocoder
+    vocoder = Vocoder(diff_args.vocoder.type, diff_args.vocoder.ckpt, device=device)
 
     # diff_model
     print(' [Loading] ' + combo_model_path)
@@ -91,43 +72,26 @@ def load_model_vocoder_from_combo(combo_model_path, device='cpu', loaded_vocoder
 def load_svc_model(args, vocoder_dimension):
     if args.model.type == 'Diffusion':
         model = Unit2Mel(
-                args.data.encoder_out_channels,
-                args.model.n_spk,
-                args.model.use_pitch_aug,
-                vocoder_dimension,
-                args.model.n_layers,
-                args.model.n_chans,
-                args.model.n_hidden,
-                use_speaker_encoder=args.model.use_speaker_encoder,
-                speaker_encoder_out_channels=args.data.speaker_encoder_out_channels)
-
-    elif args.model.type == 'DiffusionV2':
-        model = Unit2MelV2(
-                args.data.encoder_out_channels,
-                args.model.n_spk,
-                args.model.use_pitch_aug,
-                vocoder_dimension,
-                args.model.n_hidden,
-                use_speaker_encoder=args.model.use_speaker_encoder,
-                speaker_encoder_out_channels=args.data.speaker_encoder_out_channels,
-                z_rate=args.model.z_rate,
-                mean_only=args.model.mean_only,
-                max_beta=args.model.max_beta,
-                spec_min=args.model.spec_min,
-                spec_max=args.model.spec_max,
-                denoise_fn=args.model.denoise_fn,
-                mask_cond_ratio=args.model.mask_cond_ratio)
+            args.data.encoder_out_channels,
+            args.model.n_spk,
+            args.model.use_pitch_aug,
+            vocoder_dimension,
+            args.model.n_layers,
+            args.model.n_chans,
+            args.model.n_hidden,
+            use_speaker_encoder=args.model.use_speaker_encoder,
+            speaker_encoder_out_channels=args.data.speaker_encoder_out_channels)
 
     elif args.model.type == 'Naive':
         model = Unit2MelNaive(
-                args.data.encoder_out_channels,
-                args.model.n_spk,
-                args.model.use_pitch_aug,
-                vocoder_dimension,
-                args.model.n_layers,
-                args.model.n_chans,
-                use_speaker_encoder=args.model.use_speaker_encoder,
-                speaker_encoder_out_channels=args.data.speaker_encoder_out_channels)
+            args.data.encoder_out_channels,
+            args.model.n_spk,
+            args.model.use_pitch_aug,
+            vocoder_dimension,
+            args.model.n_layers,
+            args.model.n_chans,
+            use_speaker_encoder=args.model.use_speaker_encoder,
+            speaker_encoder_out_channels=args.data.speaker_encoder_out_channels)
 
     elif args.model.type == 'NaiveFS':
         model = Unit2MelNaive(
@@ -141,222 +105,12 @@ def load_svc_model(args, vocoder_dimension):
             speaker_encoder_out_channels=args.data.speaker_encoder_out_channels,
             use_full_siren=True,
             l2reg_loss=args.model.l2_reg_loss)
-
-    elif args.model.type == 'NaiveV2':
-        model = Unit2MelNaiveV2(
-            args.data.encoder_out_channels,
-            args.model.n_spk,
-            args.model.use_pitch_aug,
-            vocoder_dimension,
-            args.model.n_layers,
-            args.model.n_chans,
-            use_speaker_encoder=args.model.use_speaker_encoder,
-            speaker_encoder_out_channels=args.data.speaker_encoder_out_channels)
     else:
-        raise TypeError(" [X] Unknow model")
+        raise ("Unknow model")
     return model
 
 
-class Unit2MelV2(nn.Module):
-    def __init__(
-            self,
-            input_channel,
-            n_spk,
-            use_pitch_aug=False,
-            out_dims=128,
-            n_hidden=256,
-            use_speaker_encoder=False,
-            speaker_encoder_out_channels=256,
-            z_rate=None,
-            mean_only=False,
-            max_beta=0.02,
-            spec_min=-12,
-            spec_max=2,
-            denoise_fn=None,
-            mask_cond_ratio=None,
-    ):
-        super().__init__()
-        if mask_cond_ratio is not None:
-            mask_cond_ratio = float(mask_cond_ratio) if (str(mask_cond_ratio) != 'NOTUSE') else None
-            if mask_cond_ratio > 0:
-                self.mask_cond_ratio = mask_cond_ratio
-            else:
-                self.mask_cond_ratio = None
-        else:
-            self.mask_cond_ratio = None
-
-        if denoise_fn is None:
-            # catch None
-            denoise_fn = {'type': 'WaveNet',
-                          'wn_layers': 20,
-                          'wn_chans': 384,
-                          'wn_dilation': 1,
-                          'wn_kernel': 3,
-                          'wn_tf_use': False,
-                          'wn_tf_rf': False,
-                          'wn_tf_n_layers': 2,
-                          'wn_tf_n_head': 4}
-            denoise_fn = DotDict(denoise_fn)
-
-        if denoise_fn.type == 'WaveNet':
-            # catch None
-            self.wn_layers = denoise_fn.wn_layers if (denoise_fn.wn_layers is not None) else 20
-            self.wn_chans = denoise_fn.wn_chans if (denoise_fn.wn_chans is not None) else 384
-            self.wn_dilation = denoise_fn.wn_dilation if (denoise_fn.wn_dilation is not None) else 1
-            self.wn_kernel = denoise_fn.wn_kernel if (denoise_fn.wn_kernel is not None) else 3
-            self.wn_tf_use = denoise_fn.wn_tf_use if (denoise_fn.wn_tf_use is not None) else False
-            self.wn_tf_rf = denoise_fn.wn_tf_rf if (denoise_fn.wn_tf_rf is not None) else False
-            self.wn_tf_n_layers = denoise_fn.wn_tf_n_layers if (denoise_fn.wn_tf_n_layers is not None) else 2
-            self.wn_tf_n_head = denoise_fn.wn_tf_n_head if (denoise_fn.wn_tf_n_head is not None) else 4
-
-            # init wavenet denoiser
-            denoiser = WaveNet(out_dims, self.wn_layers, self.wn_chans, n_hidden, self.wn_dilation, self.wn_kernel,
-                               self.wn_tf_use, self.wn_tf_rf, self.wn_tf_n_layers, self.wn_tf_n_head, self.dwconv)
-
-        elif denoise_fn.type == 'ConvNext':
-            # catch None
-            self.cn_layers = denoise_fn.cn_layers if (denoise_fn.cn_layers is not None) else 20
-            self.cn_chans = denoise_fn.cn_chans if (denoise_fn.cn_chans is not None) else 384
-            self.cn_dilation_cycle = denoise_fn.cn_dilation_cycle if (denoise_fn.cn_dilation_cycle is not None) else 4
-            self.mlp_factor = denoise_fn.mlp_factor if (denoise_fn.mlp_factor is not None) else 4
-            self.gradient_checkpointing = denoise_fn.gradient_checkpointing if (
-                denoise_fn.gradient_checkpointing is not None) else False
-            # init convnext denoiser
-            denoiser = ConvNext(
-                mel_channels=out_dims,
-                dim=self.cn_chans,
-                mlp_factor=self.mlp_factor,
-                condition_dim=n_hidden,
-                num_layers=self.cn_layers,
-                dilation_cycle=self.cn_dilation_cycle,
-                gradient_checkpointing=self.gradient_checkpointing
-            )
-
-        elif denoise_fn.type == 'NaiveV2Diff':
-            # catch None
-            self.cn_layers = denoise_fn.cn_layers if (denoise_fn.cn_layers is not None) else 20
-            self.cn_chans = denoise_fn.cn_chans if (denoise_fn.cn_chans is not None) else 384
-            self.use_mlp = denoise_fn.use_mlp if (denoise_fn.use_mlp is not None) else True
-            self.mlp_factor = denoise_fn.mlp_factor if (denoise_fn.mlp_factor is not None) else 4
-            self.expansion_factor = denoise_fn.expansion_factor if (denoise_fn.expansion_factor is not None) else 2
-            self.kernel_size = denoise_fn.kernel_size if (denoise_fn.kernel_size is not None) else 31
-            self.conv_only = denoise_fn.conv_only if (denoise_fn.conv_only is not None) else True
-            self.wavenet_like = denoise_fn.wavenet_like if (denoise_fn.wavenet_like is not None) else False
-            self.use_norm = denoise_fn.use_norm if (denoise_fn.use_norm is not None) else True
-            self.conv_model_type = denoise_fn.conv_model_type if (denoise_fn.conv_model_type is not None) else 'mode1'
-            # init convnext denoiser
-            denoiser = NaiveV2Diff(
-                mel_channels=out_dims,
-                dim=self.cn_chans,
-                use_mlp=self.use_mlp,
-                mlp_factor=self.mlp_factor,
-                condition_dim=n_hidden,
-                num_layers=self.cn_layers,
-                expansion_factor=self.expansion_factor,
-                kernel_size=self.kernel_size,
-                conv_only=self.conv_only,
-                wavenet_like=self.wavenet_like,
-                use_norm=self.use_norm,
-                conv_model_type=denoise_fn.conv_model_type
-            )
-
-        else:
-            raise TypeError(" [X] Unknow denoise_fn")
-        self.denoise_fn_type = denoise_fn.type
-
-        # catch None
-        self.z_rate = z_rate
-        self.mean_only = mean_only if (mean_only is not None) else False
-        self.max_beta = max_beta if (max_beta is not None) else 0.02
-        self.spec_min = spec_min if (spec_min is not None) else -12
-        self.spec_max = spec_max if (spec_max is not None) else 2
-
-        # init embed
-        self.unit_embed = nn.Linear(input_channel, n_hidden)
-        self.f0_embed = nn.Linear(1, n_hidden)
-        self.volume_embed = nn.Linear(1, n_hidden)
-        if use_pitch_aug:
-            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
-        else:
-            self.aug_shift_embed = None
-        self.n_spk = n_spk
-        self.use_speaker_encoder = use_speaker_encoder
-        if use_speaker_encoder:
-            self.spk_embed = nn.Linear(speaker_encoder_out_channels, n_hidden, bias=False)
-        else:
-            if n_spk is not None and n_spk > 1:
-                self.spk_embed = nn.Embedding(n_spk, n_hidden)
-
-        # init diffusion
-        self.decoder = GaussianDiffusion(
-            denoiser,
-            out_dims=out_dims,
-            max_beta=self.max_beta,
-            spec_min=self.spec_min,
-            spec_max=self.spec_max)
-
-    def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
-                gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True,
-                spk_emb=None, spk_emb_dict=None, use_vae=False):
-        '''
-        input:
-            B x n_frames x n_unit
-        return:
-            dict of B x n_frames x feat
-        '''
-
-        # embed
-        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
-        if self.use_speaker_encoder:
-            if spk_mix_dict is not None:
-                assert spk_emb_dict is not None
-                for k, v in spk_mix_dict.items():
-                    spk_id_torch = spk_emb_dict[str(k)]
-                    spk_id_torch = np.tile(spk_id_torch, (len(units), 1))
-                    spk_id_torch = torch.from_numpy(spk_id_torch).float().to(units.device)
-                    x = x + v * self.spk_embed(spk_id_torch)
-            else:
-                x = x + self.spk_embed(spk_emb)
-        else:
-            if self.n_spk is not None and self.n_spk > 1:
-                if spk_mix_dict is not None:
-                    for k, v in spk_mix_dict.items():
-                        spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
-                        x = x + v * self.spk_embed(spk_id_torch - 1)
-                else:
-                    x = x + self.spk_embed(spk_id - 1)
-        if self.aug_shift_embed is not None and aug_shift is not None:
-            x = x + self.aug_shift_embed(aug_shift / 5)
-
-        # sample z or mean only
-        if use_vae and (gt_spec is not None):
-            gt_spec = get_z(gt_spec, mean_only=self.mean_only)
-            if (self.z_rate is not None) and (self.z_rate != 0):
-                gt_spec = gt_spec * self.z_rate  # scale z
-
-        # conditional mask
-        if self.mask_cond_ratio is not None:
-            if not infer:
-                if self.denoise_fn_type == 'NaiveV2Diff':
-                    self.decoder.denoise_fn.mask_cond_ratio = self.mask_cond_ratio
-
-        # diffusion
-        x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step,
-                         use_tqdm=use_tqdm)
-
-        if self.mask_cond_ratio is not None:
-            if self.denoise_fn_type == 'NaiveV2Diff':
-                self.decoder.denoise_fn.mask_cond_ratio = None
-
-        if (self.z_rate is not None) and (self.z_rate != 0):
-            x = x / self.z_rate  # scale z
-
-        return x
-
-
 class Unit2Mel(nn.Module):
-    # old version
     def __init__(
             self,
             input_channel,
@@ -385,17 +139,16 @@ def __init__(
             self.spk_embed = nn.Embedding(n_spk, n_hidden)
 
         # diffusion
-        self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden, 1, 3, False),
-                                         out_dims=out_dims, max_beta=0.02, spec_min=-12, spec_max=2)
+        self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
 
     def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
                 gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=None, use_tqdm=True,
-                spk_emb=None, spk_emb_dict=None, use_vae=False):
+                spk_emb=None, spk_emb_dict=None):
         '''
-        input:
+        input: 
             B x n_frames x n_unit
-        return:
+        return: 
             dict of B x n_frames x feat
         '''
diff --git a/diffusion/vocoder.py b/diffusion/vocoder.py
index 5ec1575..6a4c9f8 100644
--- a/diffusion/vocoder.py
+++ b/diffusion/vocoder.py
@@ -2,8 +2,6 @@
 from nsf_hifigan.nvSTFT import STFT
 from nsf_hifigan.models import load_model, load_config
 from torchaudio.transforms import Resample
-import os
-from encoder.hifi_vaegan import InferModel
 
 
 class Vocoder:
@@ -16,8 +14,6 @@ def __init__(self, vocoder_type, vocoder_ckpt, device=None):
             self.vocoder = NsfHifiGAN(vocoder_ckpt, device=device)
         elif vocoder_type == 'nsf-hifigan-log10':
             self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device=device)
-        elif vocoder_type == 'hifivaegan':
-            self.vocoder = HiFiVAEGAN(vocoder_ckpt, device=device)
         else:
             raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
@@ -98,41 +94,3 @@ def forward(self, mel, f0):
         c = 0.434294 * mel.transpose(1, 2)
         audio = self.model(c, f0)
         return audio
-
-
-class HiFiVAEGAN(torch.nn.Module):
-    def __init__(self, model_path, device=None):
-        super().__init__()
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.device = device
-        self.model_path = model_path
-        self.config_path = os.path.join(os.path.split(model_path)[0], 'config.json')
-        self.model = InferModel(self.config_path, self.model_path, device=device)
-
-    def sample_rate(self):
-        return self.model.sr
-
-    def hop_size(self):
-        return self.model.hop_size
-
-    def dimension(self):
-        return self.model.inter_channels
-
-    def extract(self, audio, keyshift=0, only_z=False):
-        if audio.shape[-1] % self.model.hop_size == 0:
-            audio = torch.cat((audio, torch.zeros_like(audio[:, :1])), dim=-1)
-        if keyshift != 0:
-            raise ValueError("HiFiVAEGAN could not use keyshift!")
-        with torch.no_grad():
-            z, m, logs = self.model.encode(audio)
-            if only_z:
-                return z.transpose(1, 2)
-            mel = torch.stack((m.transpose(-1, -2), logs.transpose(-1, -2)), dim=-1)
-        return mel
-
-    def forward(self, mel, f0):
-        with torch.no_grad():
-            z = mel.transpose(1, 2)
-            audio = self.model.decode(z)
-        return audio
diff --git a/diffusion/wavenet.py b/diffusion/wavenet.py
index 8ad7610..3d48c7e 100644
--- a/diffusion/wavenet.py
+++ b/diffusion/wavenet.py
@@ -5,7 +5,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Mish
-from transformers.models.roformer.modeling_roformer import RoFormerEncoder, RoFormerConfig
 
 
 class Conv1d(torch.nn.Conv1d):
@@ -30,14 +29,14 @@ def forward(self, x):
 
 
 class ResidualBlock(nn.Module):
-    def __init__(self, encoder_hidden, residual_channels, dilation, kernel_size=3):
+    def __init__(self, encoder_hidden, residual_channels, dilation):
         super().__init__()
         self.residual_channels = residual_channels
         self.dilated_conv = nn.Conv1d(
             residual_channels,
             2 * residual_channels,
-            kernel_size=kernel_size,
-            padding=dilation if (kernel_size == 3) else int((kernel_size-1) * dilation / 2),
+            kernel_size=3,
+            padding=dilation,
             dilation=dilation
         )
         self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
@@ -63,8 +62,7 @@ def forward(self, x, conditioner, diffusion_step):
 
 
 class WaveNet(nn.Module):
-    def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256, dilation=1, kernel_size=3,
-                 transformer_use=False, transformer_roformer_use=False, transformer_n_layers=2, transformer_n_head=4):
+    def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
         super().__init__()
         self.input_projection = Conv1d(in_dims, n_chans, 1)
         self.diffusion_embedding = SinusoidalPosEmb(n_chans)
@@ -77,35 +75,10 @@ def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256, dilation
             ResidualBlock(
                 encoder_hidden=n_hidden,
                 residual_channels=n_chans,
-                dilation=(2 ** (i % dilation)) if (dilation != 1) else 1,
-                kernel_size=kernel_size
+                dilation=1
             )
             for i in range(n_layers)
         ])
-        self.transformer_roformer_use = transformer_roformer_use if (transformer_roformer_use is not None) else False
-        if transformer_use:
-            if transformer_roformer_use:
-                self.transformer = RoFormerEncoder(
-                    RoFormerConfig(
-                        hidden_size=n_chans,
-                        max_position_embeddings=4096,
-                        num_attention_heads=transformer_n_head,
-                        num_hidden_layers=transformer_n_layers,
-                        add_cross_attention=False
-                    )
-                )
-            else:
-                transformer_layer = nn.TransformerEncoderLayer(
-                    d_model=n_chans,
-                    nhead=transformer_n_head,
-                    dim_feedforward=n_chans * 4,
-                    dropout=0.1,
-                    activation='gelu'
-                )
-                self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=transformer_n_layers)
-        else:
-            self.transformer = None
-
         self.skip_projection = Conv1d(n_chans, n_chans, 1)
         self.output_projection = Conv1d(n_chans, in_dims, 1)
         nn.init.zeros_(self.output_projection.weight)
@@ -131,10 +104,5 @@ def forward(self, spec, diffusion_step, cond):
         x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
         x = self.skip_projection(x)
         x = F.relu(x)
-        if self.transformer is not None:
-            if self.transformer_roformer_use:
-                x = self.transformer(x.transpose(1, 2))[0].transpose(1, 2)
-            else:
-                x = self.transformer(x.transpose(1, 2)).transpose(1, 2)
         x = self.output_projection(x)  # [B, mel_bins, T]
         return x[:, None, :, :]