|
| 1 | +from collections import deque |
| 2 | +from functools import partial |
| 3 | + |
| 4 | +import math |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +from torch import nn |
| 8 | +import torch.nn.functional as F |
| 9 | +from torch.nn import Conv1d |
| 10 | +from modules.commons.common_layers import Mish |
| 11 | +from modules.encoder import SvcEncoder |
| 12 | +from utils.hparams import hparams |
| 13 | + |
| 14 | + |
| 15 | +def exists(x): |
| 16 | + return x is not None |
| 17 | + |
| 18 | + |
| 19 | +def extract(a, t): |
| 20 | + return a[t].reshape((1, 1, 1, 1)) |
| 21 | + |
| 22 | + |
| 23 | +def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)): |
| 24 | + betas = np.linspace(1e-4, max_beta, timesteps) |
| 25 | + return betas |
| 26 | + |
| 27 | + |
| 28 | +def cosine_beta_schedule(timesteps, s=0.008): |
| 29 | + steps = timesteps + 1 |
| 30 | + x = np.linspace(0, steps, steps) |
| 31 | + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 |
| 32 | + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] |
| 33 | + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) |
| 34 | + return np.clip(betas, a_min=0, a_max=0.999) |
| 35 | + |
| 36 | + |
| 37 | +beta_schedule = { |
| 38 | + "cosine": cosine_beta_schedule, |
| 39 | + "linear": linear_beta_schedule, |
| 40 | +} |
| 41 | + |
| 42 | + |
| 43 | +def extract_1(a, t): |
| 44 | + return a[t].reshape((1, 1, 1, 1)) |
| 45 | + |
| 46 | + |
| 47 | +def predict_stage0(noise_pred, noise_pred_prev): |
| 48 | + return (noise_pred |
| 49 | + + noise_pred_prev) / 2 |
| 50 | + |
| 51 | + |
| 52 | +def predict_stage1(noise_pred, noise_list): |
| 53 | + return (noise_pred * 3 |
| 54 | + - noise_list[-1]) / 2 |
| 55 | + |
| 56 | + |
| 57 | +def predict_stage2(noise_pred, noise_list): |
| 58 | + return (noise_pred * 23 |
| 59 | + - noise_list[-1] * 16 |
| 60 | + + noise_list[-2] * 5) / 12 |
| 61 | + |
| 62 | + |
| 63 | +def predict_stage3(noise_pred, noise_list): |
| 64 | + return (noise_pred * 55 |
| 65 | + - noise_list[-1] * 59 |
| 66 | + + noise_list[-2] * 37 |
| 67 | + - noise_list[-3] * 9) / 24 |
| 68 | + |
| 69 | + |
| 70 | +class SinusoidalPosEmb(nn.Module): |
| 71 | + def __init__(self, dim): |
| 72 | + super().__init__() |
| 73 | + self.dim = dim |
| 74 | + self.half_dim = dim // 2 |
| 75 | + self.emb = 9.21034037 / (self.half_dim - 1) |
| 76 | + self.emb = torch.exp(torch.arange(self.half_dim) * torch.tensor(-self.emb)).unsqueeze(0) |
| 77 | + self.emb = self.emb.cuda() |
| 78 | + |
| 79 | + def forward(self, x): |
| 80 | + emb = self.emb * x |
| 81 | + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) |
| 82 | + return emb |
| 83 | + |
| 84 | + |
| 85 | +class ResidualBlock(nn.Module): |
| 86 | + def __init__(self, encoder_hidden, residual_channels, dilation): |
| 87 | + super().__init__() |
| 88 | + self.residual_channels = residual_channels |
| 89 | + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) |
| 90 | + self.diffusion_projection = nn.Linear(residual_channels, residual_channels) |
| 91 | + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) |
| 92 | + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) |
| 93 | + |
| 94 | + def forward(self, x, conditioner, diffusion_step): |
| 95 | + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) |
| 96 | + conditioner = self.conditioner_projection(conditioner) |
| 97 | + y = x + diffusion_step |
| 98 | + y = self.dilated_conv(y) + conditioner |
| 99 | + |
| 100 | + gate, filter_1 = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) |
| 101 | + |
| 102 | + y = torch.sigmoid(gate) * torch.tanh(filter_1) |
| 103 | + y = self.output_projection(y) |
| 104 | + |
| 105 | + residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) |
| 106 | + |
| 107 | + return (x + residual) / 1.41421356, skip |
| 108 | + |
| 109 | + |
| 110 | +class DiffNet(nn.Module): |
| 111 | + def __init__(self, in_dims=80): |
| 112 | + super().__init__() |
| 113 | + self.encoder_hidden = hparams['hidden_size'] |
| 114 | + self.residual_layers = hparams['residual_layers'] |
| 115 | + self.residual_channels = hparams['residual_channels'] |
| 116 | + self.dilation_cycle_length = hparams['dilation_cycle_length'] |
| 117 | + self.input_projection = Conv1d(in_dims, self.residual_channels, 1) |
| 118 | + self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels) |
| 119 | + dim = self.residual_channels |
| 120 | + self.mlp = nn.Sequential( |
| 121 | + nn.Linear(dim, dim * 4), |
| 122 | + Mish(), |
| 123 | + nn.Linear(dim * 4, dim) |
| 124 | + ) |
| 125 | + self.residual_layers = nn.ModuleList([ |
| 126 | + ResidualBlock(self.encoder_hidden, self.residual_channels, 2 ** (i % self.dilation_cycle_length)) |
| 127 | + for i in range(self.residual_layers) |
| 128 | + ]) |
| 129 | + self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1) |
| 130 | + self.output_projection = Conv1d(self.residual_channels, in_dims, 1) |
| 131 | + nn.init.zeros_(self.output_projection.weight) |
| 132 | + |
| 133 | + def forward(self, spec, diffusion_step, cond): |
| 134 | + x = spec.squeeze(0) |
| 135 | + x = self.input_projection(x) # x [B, residual_channel, T] |
| 136 | + x = F.relu(x) |
| 137 | + # skip = torch.randn_like(x) |
| 138 | + diffusion_step = diffusion_step.float() |
| 139 | + diffusion_step = self.diffusion_embedding(diffusion_step) |
| 140 | + diffusion_step = self.mlp(diffusion_step) |
| 141 | + |
| 142 | + x, skip = self.residual_layers[0](x, cond, diffusion_step) |
| 143 | + # noinspection PyTypeChecker |
| 144 | + for layer in self.residual_layers[1:]: |
| 145 | + x, skip_connection = layer.forward(x, cond, diffusion_step) |
| 146 | + skip = skip + skip_connection |
| 147 | + x = skip / math.sqrt(len(self.residual_layers)) |
| 148 | + x = self.skip_projection(x) |
| 149 | + x = F.relu(x) |
| 150 | + x = self.output_projection(x) # [B, 80, T] |
| 151 | + return x.unsqueeze(1) |
| 152 | + |
| 153 | + |
| 154 | +class GaussianDiffusion(nn.Module): |
| 155 | + def __init__(self, phone_encoder, out_dims, denoise_fn, |
| 156 | + timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, |
| 157 | + spec_max=None): |
| 158 | + super().__init__() |
| 159 | + self.denoise_fn = DiffNet(out_dims) |
| 160 | + self.fs2 = SvcEncoder(phone_encoder, out_dims) |
| 161 | + self.mel_bins = out_dims |
| 162 | + |
| 163 | + if exists(betas): |
| 164 | + betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas |
| 165 | + else: |
| 166 | + if 'schedule_type' in hparams.keys(): |
| 167 | + betas = beta_schedule[hparams['schedule_type']](timesteps) |
| 168 | + else: |
| 169 | + betas = cosine_beta_schedule(timesteps) |
| 170 | + |
| 171 | + alphas = 1. - betas |
| 172 | + alphas_cumprod = np.cumprod(alphas, axis=0) |
| 173 | + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) |
| 174 | + |
| 175 | + timesteps, = betas.shape |
| 176 | + self.num_timesteps = int(timesteps) |
| 177 | + self.K_step = K_step |
| 178 | + self.loss_type = loss_type |
| 179 | + |
| 180 | + self.noise_list = deque(maxlen=4) |
| 181 | + |
| 182 | + to_torch = partial(torch.tensor, dtype=torch.float32) |
| 183 | + |
| 184 | + self.register_buffer('betas', to_torch(betas)) |
| 185 | + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) |
| 186 | + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) |
| 187 | + |
| 188 | + # calculations for diffusion q(x_t | x_{t-1}) and others |
| 189 | + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) |
| 190 | + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) |
| 191 | + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) |
| 192 | + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) |
| 193 | + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) |
| 194 | + |
| 195 | + # calculations for posterior q(x_{t-1} | x_t, x_0) |
| 196 | + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) |
| 197 | + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) |
| 198 | + self.register_buffer('posterior_variance', to_torch(posterior_variance)) |
| 199 | + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain |
| 200 | + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) |
| 201 | + self.register_buffer('posterior_mean_coef1', to_torch( |
| 202 | + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) |
| 203 | + self.register_buffer('posterior_mean_coef2', to_torch( |
| 204 | + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) |
| 205 | + |
| 206 | + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) |
| 207 | + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) |
| 208 | + self.mel_vmin = hparams['mel_vmin'] |
| 209 | + self.mel_vmax = hparams['mel_vmax'] |
| 210 | + |
| 211 | + def get_x_pred(self, x_1, noise_t, t_1, t_prev): |
| 212 | + a_t = extract(self.alphas_cumprod, t_1) |
| 213 | + a_prev = extract(self.alphas_cumprod, t_prev) |
| 214 | + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() |
| 215 | + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x_1 - 1 / ( |
| 216 | + a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) |
| 217 | + x_pred = x_1 + x_delta |
| 218 | + return x_pred |
| 219 | + |
| 220 | + def forward(self, hubert, mel2ph=None, spk_embed=None, f0=None, initial_noise=None, speedup=None): |
| 221 | + decoder_inp, f0_denorm = self.fs2(hubert, mel2ph, spk_embed, f0) |
| 222 | + cond = decoder_inp.transpose(1, 2) |
| 223 | + x = initial_noise |
| 224 | + pndms = speedup[0] |
| 225 | + device = cond.device |
| 226 | + n_frames = cond.shape[2] |
| 227 | + step_range = torch.arange(0, self.K_step, pndms, dtype=torch.long, device=device).flip(0) |
| 228 | + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) |
| 229 | + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) |
| 230 | + for t in step_range: |
| 231 | + t_1 = torch.full((1,), t, device=device, dtype=torch.long) |
| 232 | + noise_pred = self.denoise_fn(x, t_1, cond) |
| 233 | + t_prev = t_1 - pndms |
| 234 | + t_prev = t_prev * (t_prev > 0) |
| 235 | + if plms_noise_stage == 0: |
| 236 | + x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev) |
| 237 | + noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond) |
| 238 | + noise_pred_prime = predict_stage0(noise_pred, noise_pred_prev) |
| 239 | + elif plms_noise_stage == 1: |
| 240 | + noise_pred_prime = predict_stage1(noise_pred, noise_list) |
| 241 | + elif plms_noise_stage == 2: |
| 242 | + noise_pred_prime = predict_stage2(noise_pred, noise_list) |
| 243 | + else: |
| 244 | + noise_pred_prime = predict_stage3(noise_pred, noise_list) |
| 245 | + noise_pred = noise_pred.unsqueeze(0) |
| 246 | + if plms_noise_stage < 3: |
| 247 | + noise_list = torch.cat((noise_list, noise_pred), dim=0) |
| 248 | + plms_noise_stage = plms_noise_stage + 1 |
| 249 | + else: |
| 250 | + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) |
| 251 | + x = self.get_x_pred(x, noise_pred_prime, t_1, t_prev) |
| 252 | + |
| 253 | + x = x.squeeze(1).permute(0, 2, 1) |
| 254 | + d = (self.spec_max - self.spec_min) / 2 |
| 255 | + m = (self.spec_max + self.spec_min) / 2 |
| 256 | + mel_out = x * d + m |
| 257 | + # mel_out[mel_out > self.mel_vmax] = self.mel_vmax |
| 258 | + # mel_out[mel_out < self.mel_vmin] = self.mel_vmin |
| 259 | + mel_out = mel_out * 2.30259 |
| 260 | + return mel_out.transpose(2, 1), f0_denorm |
0 commit comments