Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using e as the base number of the log mel-spectrogram #175

Merged
merged 6 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions augmentation/spec_stretch.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from copy import deepcopy

import librosa
import numpy as np
import torch

from basics.base_augmentation import BaseAugmentation, require_same_keys
from basics.base_pe import BasePE
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.fastspeech.tts_modules import LengthRegulator
from modules.vocoders.registry import VOCODERS
from utils.binarizer_utils import get_mel2ph_torch
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve

Expand All @@ -27,14 +27,13 @@ def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
@require_same_keys
def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
aug_item = deepcopy(item)
if hparams['vocoder'] in VOCODERS:
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(
aug_item['wav_fn'], keyshift=key_shift, speed=speed
)
else:
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(
aug_item['wav_fn'], keyshift=key_shift, speed=speed
)
waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
mel = get_mel_torch(
waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
fmin=hparams['fmin'], fmax=hparams['fmax'], mel_base=hparams['mel_base'],
keyshift=key_shift, speed=speed, device=self.device
)

aug_item['mel'] = mel

Expand All @@ -48,7 +47,7 @@ def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None)
).cpu().numpy()

f0, _ = self.pe.get_pitch(
wav, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
speed=speed, interp_uv=True
)
Expand Down
9 changes: 0 additions & 9 deletions basics/base_vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,3 @@ def spec2wav(self, mel, **kwargs):
"""

raise NotImplementedError()

@staticmethod
def wav2spec(wav_fn):
"""

:param wav_fn: str
:return: wav, mel: [T, 80]
"""
raise NotImplementedError()
3 changes: 2 additions & 1 deletion configs/acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
dictionary: dictionaries/opencpop-extension.txt
spec_min: [-5]
spec_max: [0]
mel_vmin: -6. #-6.
mel_vmin: -6.
mel_vmax: 1.5
mel_base: '10'
energy_smooth_width: 0.12
breathiness_smooth_width: 0.12
voicing_smooth_width: 0.12
Expand Down
2 changes: 1 addition & 1 deletion deployment/exporters/acoustic_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def export_attachments(self, path: Path):
dsconfig['num_mel_bins'] = hparams['audio_num_mel_bins']
dsconfig['mel_fmin'] = hparams['fmin']
dsconfig['mel_fmax'] = hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2
dsconfig['mel_base'] = '10'
dsconfig['mel_base'] = str(hparams.get('mel_base', '10'))
dsconfig['mel_scale'] = 'slaney'
config_path = path / 'dsconfig.yaml'
with open(config_path, 'w', encoding='utf8') as fw:
Expand Down
4 changes: 2 additions & 2 deletions deployment/exporters/nsf_hifigan_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def build_model(self) -> nn.Module:
config_path = self.model_path.with_name('config.json')
with open(config_path, 'r', encoding='utf8') as f:
config = json.load(f)
model = NSFHiFiGANONNX(config).eval().to(self.device)
model = NSFHiFiGANONNX(config, mel_base=hparams.get('mel_base', '10')).eval().to(self.device)
load_ckpt(model.generator, str(self.model_path),
prefix_in_ckpt=None, key_in_ckpt='generator',
strict=True, device=self.device)
Expand Down Expand Up @@ -65,7 +65,7 @@ def export_attachments(self, path: Path):
'num_mel_bins': hparams['audio_num_mel_bins'],
'mel_fmin': hparams['fmin'],
'mel_fmax': hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2,
'mel_base': '10',
'mel_base': str(hparams.get('mel_base', '10')),
'mel_scale': 'slaney',
}, fw, sort_keys=False)
print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
Expand Down
9 changes: 7 additions & 2 deletions deployment/modules/nsf_hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@

# noinspection SpellCheckingInspection
class NSFHiFiGANONNX(torch.nn.Module):
def __init__(self, attrs: dict):
def __init__(self, attrs: dict, mel_base='e'):
super().__init__()
self.mel_base = str(mel_base)
assert self.mel_base in ['e', '10'], "mel_base must be 'e', '10' or 10."
self.generator = Generator(AttrDict(attrs))

def forward(self, mel: torch.Tensor, f0: torch.Tensor):
mel = mel.transpose(1, 2) * 2.30259
mel = mel.transpose(1, 2)
if self.mel_base != 'e':
# log10 to log mel
mel = mel * 2.30259
wav = self.generator(mel, f0)
return wav.squeeze(1)
12 changes: 10 additions & 2 deletions inference/val_nsf_hifigan.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import sys

import librosa
import numpy as np
import resampy
import torch
import torchcrepe
import tqdm

from utils.binarizer_utils import get_pitch_parselmouth
from utils.binarizer_utils import get_pitch_parselmouth, get_mel_torch
from modules.vocoders.nsf_hifigan import NsfHifiGAN
from utils.infer_utils import save_wav
from utils.hparams import set_hparams, hparams
Expand Down Expand Up @@ -60,7 +61,14 @@ def get_pitch(wav_data, mel, hparams, threshold=0.3):
for filename in tqdm.tqdm(os.listdir(in_path)):
if not filename.endswith('.wav'):
continue
wav, mel = vocoder.wav2spec(os.path.join(in_path, filename))
wav, _ = librosa.load(os.path.join(in_path, filename), sr=hparams['audio_sample_rate'], mono=True)
mel = get_mel_torch(
wav, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
fmin=hparams['fmin'], fmax=hparams['fmax'], mel_base=hparams['mel_base'],
device=device
)

f0, _ = get_pitch_parselmouth(
wav, samplerate=hparams['audio_sample_rate'], length=len(mel),
hop_size=hparams['hop_size']
Expand Down
11 changes: 0 additions & 11 deletions modules/nsf_hifigan/nvSTFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@
import torch
import torch.utils.data
import numpy as np
import librosa
from librosa.filters import mel as librosa_mel_fn
import torch.nn.functional as F


def load_wav_to_torch(full_path, target_sr=None):
data, sr = librosa.load(full_path, sr=target_sr, mono=True)
return torch.from_numpy(data), sr


def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)

Expand Down Expand Up @@ -96,8 +90,3 @@ def get_mel(self, y, keyshift=0, speed=1, center=False):
spec = dynamic_range_compression_torch(spec, clip_val=clip_val)

return spec

def __call__(self, audiopath):
audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
return spect
131 changes: 18 additions & 113 deletions modules/vocoders/ddsp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import pathlib

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import yaml
import numpy as np
from librosa.filters import mel as librosa_mel_fn

from basics.base_vocoder import BaseVocoder
from modules.vocoders.registry import register_vocoder
from utils.hparams import hparams
Expand Down Expand Up @@ -35,83 +34,6 @@ def load_model(model_path: pathlib.Path, device='cpu'):
return model, args


class Audio2Mel(torch.nn.Module):
def __init__(
self,
hop_length,
sampling_rate,
n_mel_channels,
win_length,
n_fft=None,
mel_fmin=0,
mel_fmax=None,
clamp=1e-5
):
super().__init__()
n_fft = win_length if n_fft is None else n_fft
self.hann_window = {}
mel_basis = librosa_mel_fn(
sr=sampling_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=mel_fmin,
fmax=mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer("mel_basis", mel_basis)
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.sampling_rate = sampling_rate
self.n_mel_channels = n_mel_channels
self.clamp = clamp

def forward(self, audio, keyshift=0, speed=1):
'''
audio: B x C x T
log_mel_spec: B x T_ x C x n_mel
'''
factor = 2 ** (keyshift / 12)
n_fft_new = int(np.round(self.n_fft * factor))
win_length_new = int(np.round(self.win_length * factor))
hop_length_new = int(np.round(self.hop_length * speed))

keyshift_key = str(keyshift) + '_' + str(audio.device)
if keyshift_key not in self.hann_window:
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)

B, C, T = audio.shape
audio = audio.reshape(B * C, T)
fft = torch.stft(
audio,
n_fft=n_fft_new,
hop_length=hop_length_new,
win_length=win_length_new,
window=self.hann_window[keyshift_key],
center=True,
return_complex=False)
real_part, imag_part = fft.unbind(-1)
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)

if keyshift != 0:
size = self.n_fft // 2 + 1
resize = magnitude.size(1)
if resize < size:
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

mel_output = torch.matmul(self.mel_basis, magnitude)
log_mel_spec = torch.log10(torch.clamp(mel_output, min=self.clamp))

# log_mel_spec: B x C, M, T
T_ = log_mel_spec.shape[-1]
log_mel_spec = log_mel_spec.reshape(B, C, self.n_mel_channels, T_)
log_mel_spec = log_mel_spec.permute(0, 3, 1, 2)

# print('og_mel_spec:', log_mel_spec.shape)
log_mel_spec = log_mel_spec.squeeze(2) # mono
return log_mel_spec


@register_vocoder
class DDSP(BaseVocoder):
def __init__(self, device='cpu'):
Expand Down Expand Up @@ -149,8 +71,15 @@ def spec2wav_torch(self, mel, f0): # mel: [B, T, bins] f0: [B, T]
print('Mismatch parameters: hparams[\'fmax\']=', hparams['fmax'], '!=', self.args.data.mel_fmax,
'(vocoder)')
with torch.no_grad():
f0 = f0.unsqueeze(-1)
signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device))
mel = mel.to(self.device)
mel_base = hparams.get('mel_base', 10)
if mel_base != 'e':
assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10."
else:
# log mel to log10 mel
mel = 0.434294 * mel
f0 = f0.unsqueeze(-1).to(self.device)
signal, _, (s_h, s_n) = self.model(mel, f0)
signal = signal.view(-1)
return signal

Expand Down Expand Up @@ -178,38 +107,14 @@ def spec2wav(self, mel, f0):
'(vocoder)')
with torch.no_grad():
mel = torch.FloatTensor(mel).unsqueeze(0).to(self.device)
mel_base = hparams.get('mel_base', 10)
if mel_base != 'e':
assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10."
else:
# log mel to log10 mel
mel = 0.434294 * mel
f0 = torch.FloatTensor(f0).unsqueeze(0).unsqueeze(-1).to(self.device)
signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device))
signal, _, (s_h, s_n) = self.model(mel, f0)
signal = signal.view(-1)
wav_out = signal.cpu().numpy()
return wav_out

@staticmethod
def wav2spec(inp_path, keyshift=0, speed=1, device=None):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sampling_rate = hparams['audio_sample_rate']
n_mel_channels = hparams['audio_num_mel_bins']
n_fft = hparams['fft_size']
win_length = hparams['win_size']
hop_length = hparams['hop_size']
mel_fmin = hparams['fmin']
mel_fmax = hparams['fmax']

# load input
x, _ = librosa.load(inp_path, sr=sampling_rate)
x_t = torch.from_numpy(x).float().to(device)
x_t = x_t.unsqueeze(0).unsqueeze(0) # (T,) --> (1, 1, T)

# mel analysis
mel_extractor = Audio2Mel(
hop_length=hop_length,
sampling_rate=sampling_rate,
n_mel_channels=n_mel_channels,
win_length=win_length,
n_fft=n_fft,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax).to(device)

mel = mel_extractor(x_t, keyshift=keyshift, speed=speed)
return x, mel.squeeze(0).cpu().numpy()
Loading