diff --git a/.github/workflows/tts_tests2.yml b/.github/workflows/tts_tests2.yml new file mode 100644 index 0000000000..f64433f8df --- /dev/null +++ b/.github/workflows/tts_tests2.yml @@ -0,0 +1,53 @@ +name: tts-tests2 + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.9, "3.10", "3.11"] + experimental: [false] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: set ENV + run: export TRAINER_TELEMETRY=0 + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + sudo apt-get install espeak + sudo apt-get install espeak-ng + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Replace scarf urls + run: | + sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: make test_tts2 diff --git a/Makefile b/Makefile index 9f2008dc67..ab992ec52e 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,9 @@ test_vocoder: ## run vocoder tests. test_tts: ## run tts tests. nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests +test_tts2: ## run tts tests. + nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2 + test_aux: ## run aux tests. nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests ./run_bash_tests.sh diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 224715aacf..dbe7f99856 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -430,9 +430,9 @@ def main(): if tts_path is not None: wav = synthesizer.tts( args.text, - args.speaker_idx, - args.language_idx, - args.speaker_wav, + speaker_name=args.speaker_idx, + language_name=args.language_idx, + speaker_wav=args.speaker_wav, reference_wav=args.reference_wav, style_wav=args.capacitron_style_wav, style_text=args.capacitron_style_text, diff --git a/TTS/tts/configs/delightful_tts_config.py b/TTS/tts/configs/delightful_tts_config.py new file mode 100644 index 0000000000..50ab60af81 --- /dev/null +++ b/TTS/tts/configs/delightful_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig + + +@dataclass +class DelightfulTTSConfig(BaseTTSConfig): + """ + Configuration class for the DelightfulTTS model. + + Attributes: + model (str): Name of the model ("delightful_tts"). + audio (DelightfulTtsAudioConfig): Configuration for audio settings. + model_args (DelightfulTtsArgs): Configuration for model arguments. + use_attn_priors (bool): Whether to use attention priors. + vocoder (VocoderConfig): Configuration for the vocoder. + init_discriminator (bool): Whether to initialize the discriminator. 
+        steps_to_start_discriminator (int): Number of training steps before the discriminator starts training.
+        grad_clip (List[float]): Gradient clipping values.
+        lr_gen (float): Learning rate for the GAN generator.
+        lr_disc (float): Learning rate for the GAN discriminator.
+        lr_scheduler_gen (str): Name of the learning rate scheduler for the generator.
+        lr_scheduler_gen_params (dict): Parameters for the learning rate scheduler for the generator.
+        lr_scheduler_disc (str): Name of the learning rate scheduler for the discriminator.
+        lr_scheduler_disc_params (dict): Parameters for the learning rate scheduler for the discriminator.
+        scheduler_after_epoch (bool): Whether to schedule after each epoch.
+        optimizer (str): Name of the optimizer.
+        optimizer_params (dict): Parameters for the optimizer.
+        ssim_loss_alpha (float): Alpha value for the SSIM loss.
+        mel_loss_alpha (float): Alpha value for the mel loss.
+        aligner_loss_alpha (float): Alpha value for the aligner loss.
+        pitch_loss_alpha (float): Alpha value for the pitch loss.
+        energy_loss_alpha (float): Alpha value for the energy loss.
+        u_prosody_loss_alpha (float): Alpha value for the utterance prosody loss.
+        p_prosody_loss_alpha (float): Alpha value for the phoneme prosody loss.
+        dur_loss_alpha (float): Alpha value for the duration loss.
+        char_dur_loss_alpha (float): Alpha value for the character duration loss.
+        binary_align_loss_alpha (float): Alpha value for the binary alignment loss.
+        binary_loss_warmup_epochs (int): Number of warm-up epochs for the binary loss.
+        disc_loss_alpha (float): Alpha value for the discriminator loss.
+        gen_loss_alpha (float): Alpha value for the generator loss.
+        feat_loss_alpha (float): Alpha value for the feature loss.
+        vocoder_mel_loss_alpha (float): Alpha value for the vocoder mel loss.
+        multi_scale_stft_loss_alpha (float): Alpha value for the multi-scale STFT loss.
+        multi_scale_stft_loss_params (dict): Parameters for the multi-scale STFT loss.
+        return_wav (bool): Whether to return audio waveforms.
+        use_weighted_sampler (bool): Whether to use a weighted sampler.
+        weighted_sampler_attrs (dict): Attributes for the weighted sampler.
+        weighted_sampler_multipliers (dict): Multipliers for the weighted sampler.
+        r (int): Value for the `r` override.
+        compute_f0 (bool): Whether to compute F0 values.
+        f0_cache_path (str): Path to the F0 cache.
+        attn_prior_cache_path (str): Path to the attention prior cache.
+        num_speakers (int): Number of speakers.
+        use_speaker_embedding (bool): Whether to use speaker embedding.
+        speakers_file (str): Path to the speaker file.
+        speaker_embedding_channels (int): Number of channels for the speaker embedding.
+        language_ids_file (str): Path to the language IDs file.
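        Example (editorial sketch, not part of this diff; fields not documented above, such as
        ``batch_size``, are assumed to be inherited from ``BaseTTSConfig``, and the cache path is
        a hypothetical placeholder):

            from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig

            config = DelightfulTTSConfig(
                batch_size=32,                # assumed inherited training field
                lr_gen=0.0002,                # generator learning rate
                lr_disc=0.0002,               # discriminator learning rate
                compute_f0=True,              # precompute F0 for the dataset
                f0_cache_path="f0_cache/",    # hypothetical cache directory
                use_speaker_embedding=False,  # single-speaker setup
            )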
+ """ + + model: str = "delightful_tts" + + # model specific params + audio: DelightfulTtsAudioConfig = field(default_factory=DelightfulTtsAudioConfig) + model_args: DelightfulTtsArgs = field(default_factory=DelightfulTtsArgs) + use_attn_priors: bool = True + + # vocoder + vocoder: VocoderConfig = field(default_factory=VocoderConfig) + init_discriminator: bool = True + + # optimizer + steps_to_start_discriminator: int = 200000 + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # acoustic model loss params + ssim_loss_alpha: float = 1.0 + mel_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + energy_loss_alpha: float = 1.0 + u_prosody_loss_alpha: float = 0.5 + p_prosody_loss_alpha: float = 0.5 + dur_loss_alpha: float = 1.0 + char_dur_loss_alpha: float = 0.01 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 10 + + # vocoder loss params + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + vocoder_mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # data loader params + return_wav: bool = True + use_weighted_sampler: bool = False + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + attn_prior_cache_path: str = None + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: str = None + d_vector_dim: int = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index df01d66323..c673c963b6 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -686,6 +686,7 @@ def __init__( self, samples: Union[List[List], List[Dict]], ap: "AudioProcessor", + audio_config=None, # pylint: disable=unused-argument verbose=False, cache_path: str = None, precompute_num_workers=0, diff --git a/TTS/tts/layers/delightful_tts/__init__.py b/TTS/tts/layers/delightful_tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py new file mode 100644 index 0000000000..c906b882e5 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -0,0 +1,563 @@ +### credit: https://github.com/dunky11/voicesmith +from typing import Callable, Dict, Tuple + +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.delightful_tts.conformer import Conformer +from TTS.tts.layers.delightful_tts.encoders import ( + PhonemeLevelProsodyEncoder, + UtteranceLevelProsodyEncoder, + get_mask_from_lengths, +) +from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor +from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding +from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor +from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask + + +class AcousticModel(torch.nn.Module): + def __init__( + self, + args: "ModelArgs", + tokenizer: "TTSTokenizer" = None, + speaker_manager: "SpeakerManager" = None, + ): + super().__init__() + self.args = args + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + + self.init_multispeaker(args) + # self.set_embedding_dims() + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb_dim = args.n_hidden_conformer_encoder + self.encoder = Conformer( + dim=self.args.n_hidden_conformer_encoder, + n_layers=self.args.n_layers_conformer_encoder, + n_heads=self.args.n_heads_conformer_encoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_encoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + self.pitch_adaptor = PitchAdaptor( + n_input=self.args.n_hidden_conformer_encoder, + n_hidden=self.args.n_hidden_variance_adaptor, + n_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) 
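        # The energy adaptor below mirrors the pitch adaptor above: each predicts a per-phoneme
        # scalar (averaged over the aligner durations at train time) plus an embedding that
        # forward() adds onto the encoder outputs before decoding.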
+ self.energy_adaptor = EnergyAdaptor( + channels_in=self.args.n_hidden_conformer_encoder, + channels_hidden=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, + in_key_channels=self.args.n_hidden_conformer_encoder, + ) + + self.duration_predictor = VariancePredictor( + channels_in=self.args.n_hidden_conformer_encoder, + channels=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_u=self.args.bottleneck_size_u_reference_encoder, + token_num=self.args.token_num_reference_encoder, + ) + + self.utterance_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_u_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_p=self.args.bottleneck_size_p_reference_encoder, + n_heads=self.args.n_heads_conformer_encoder, + ) + + self.phoneme_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_p_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.u_bottle_out = nn.Linear( + self.args.bottleneck_size_u_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + + self.u_norm = nn.InstanceNorm1d(self.args.bottleneck_size_u_reference_encoder) + self.p_bottle_out = nn.Linear( + self.args.bottleneck_size_p_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + self.p_norm = nn.InstanceNorm1d( + self.args.bottleneck_size_p_reference_encoder, + ) + self.decoder = Conformer( + dim=self.args.n_hidden_conformer_decoder, + n_layers=self.args.n_layers_conformer_decoder, + n_heads=self.args.n_heads_conformer_decoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_decoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_decoder, + lrelu_slope=self.args.lrelu_slope, + ) + + padding_idx = self.tokenizer.characters.pad_id + self.src_word_emb = EmbeddingPadded( + self.args.num_chars, 
self.args.n_hidden_conformer_encoder, padding_idx=padding_idx + ) + self.to_mel = nn.Linear( + self.args.n_hidden_conformer_decoder, + self.args.num_mels, + ) + + self.energy_scaler = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=True, momentum=None) + self.energy_scaler.requires_grad_(False) + + def init_multispeaker(self, args: Coqpit): # pylint: disable=unused-argument + """Init for multi-speaker training.""" + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]) # .unsqueeze_(-1) + if g.ndim == 2: + g = g # .unsqueeze_(0) # pylint: disable=self-assigning-variable + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + # def set_embedding_dims(self): + # if self.embedded_speaker_dim > 0: + # self.embedding_dims = self.embedded_speaker_dim + # else: + # self.embedding_dims = 0 + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the linear scale durations. + + Args: + dr (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. 
+ + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def _expand_encoder_with_durations( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.IntTensor, + y_lengths: torch.IntTensor, + ): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) + return y_mask, o_en_ex, attn.transpose(1, 2) + + def _forward_aligner( + self, + x: torch.FloatTensor, + y: torch.FloatTensor, + x_mask: torch.IntTensor, + y_mask: torch.IntTensor, + attn_priors: torch.FloatTensor, + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + attn_priors (torch.FloatTensor): Prior for the aligner network map. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. 
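        Example (editorial sketch, not part of this diff): durations are obtained by summing the
        hard MAS alignment over the spectrogram axis, as done with ``torch.sum(aligner_mas, -1)``
        in the body below.

            import torch

            aligner_mas = torch.zeros(1, 4, 10)           # [B, T_en, T_de], one token per frame
            aligner_mas[0, 0, 0:2] = 1.0                  # token 0 -> frames 0-1
            aligner_mas[0, 1, 2:5] = 1.0                  # token 1 -> frames 2-4
            aligner_mas[0, 2, 5:6] = 1.0                  # token 2 -> frame 5
            aligner_mas[0, 3, 6:10] = 1.0                 # token 3 -> frames 6-9
            durations = torch.sum(aligner_mas, -1).int()  # tensor([[2, 3, 1, 4]])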
+ + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + - attn_priors: :math:`[B, T_de, T_en]` + + - aligner_durations: :math:`[B, T_en]` + - aligner_soft: :math:`[B, T_de, T_en]` + - aligner_logprob: :math:`[B, 1, T_de, T_en]` + - aligner_mas: :math:`[B, T_de, T_en]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # [B, 1, T_en, T_de] + aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, attn_priors) + aligner_mas = maximum_path( + aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + aligner_durations = torch.sum(aligner_mas, -1).int() + aligner_soft = aligner_soft.squeeze(1) # [B, T_max2, T_max] + aligner_mas = aligner_mas.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max] + return aligner_durations, aligner_soft, aligner_logprob, aligner_mas + + def average_utterance_prosody( # pylint: disable=no-self-use + self, u_prosody_pred: torch.Tensor, src_mask: torch.Tensor + ) -> torch.Tensor: + lengths = ((~src_mask) * 1.0).sum(1) + u_prosody_pred = u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1) + return u_prosody_pred + + def forward( + self, + tokens: torch.Tensor, + src_lens: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + pitches: torch.Tensor, + energies: torch.Tensor, + attn_priors: torch.Tensor, + use_ground_truth: bool = True, + d_vectors: torch.Tensor = None, + speaker_idx: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + src_mask = get_mask_from_lengths(src_lens) # [B, T_src] + mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel] + + # Token embeddings + token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden] + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Alignment network and durations + aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner( + x=token_embeddings, + y=mels.transpose(1, 2), + x_mask=~src_mask[:, None], + y_mask=~mel_mask[:, None], + attn_priors=attn_priors, + ) + dr = aligner_durations # [B, T_en] + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + max(token_embeddings.shape[1], max(mel_lens)), + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_ref = self.u_norm(self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens)) + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + + if use_ground_truth: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred) + + p_prosody_ref = self.p_norm( + self.phoneme_prosody_encoder( + x=encoder_outputs, src_mask=src_mask, mels=mels, mel_lens=mel_lens, encoding=pos_encoding + ) + ) + p_prosody_pred = self.p_norm(self.phoneme_prosody_predictor(x=encoder_outputs, mask=src_mask)) + + if use_ground_truth: + 
encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred) + + encoder_outputs_res = encoder_outputs + + pitch_pred, avg_pitch_target, pitch_emb = self.pitch_adaptor.get_pitch_embedding_train( + x=encoder_outputs, + target=pitches, + dr=dr, + mask=src_mask, + ) + + energy_pred, avg_energy_target, energy_emb = self.energy_adaptor.get_energy_embedding_train( + x=encoder_outputs, + target=energies, + dr=dr, + mask=src_mask, + ) + + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb + log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) + + mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] + ) + + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + x = self.to_mel(x) + + dr = torch.log(dr + 1) + + dr_pred = torch.exp(log_duration_prediction) - 1 + alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] + + return { + "model_outputs": x, + "pitch_pred": pitch_pred, + "pitch_target": avg_pitch_target, + "energy_pred": energy_pred, + "energy_target": avg_energy_target, + "u_prosody_pred": u_prosody_pred, + "u_prosody_ref": u_prosody_ref, + "p_prosody_pred": p_prosody_pred, + "p_prosody_ref": p_prosody_ref, + "alignments_dp": alignments_dp, + "alignments": alignments, # [B, T_de, T_en] + "aligner_soft": aligner_soft, + "aligner_mas": aligner_mas, + "aligner_durations": aligner_durations, + "aligner_logprob": aligner_logprob, + "dr_log_pred": log_duration_prediction.squeeze(1), # [B, T] + "dr_log_target": dr.squeeze(1), # [B, T] + "spk_emb": speaker_embedding, + } + + @torch.no_grad() + def inference( + self, + tokens: torch.Tensor, + speaker_idx: torch.Tensor, + p_control: float = None, # TODO # pylint: disable=unused-argument + d_control: float = None, # TODO # pylint: disable=unused-argument + d_vectors: torch.Tensor = None, + pitch_transform: Callable = None, + energy_transform: Callable = None, + ) -> torch.Tensor: + src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) + src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + token_embeddings = self.src_word_emb(tokens) + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + token_embeddings.shape[1], + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred).expand_as(encoder_outputs) + + p_prosody_pred = self.p_norm( + self.phoneme_prosody_predictor( + x=encoder_outputs, 
+ mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred).expand_as(encoder_outputs) + + encoder_outputs_res = encoder_outputs + + pitch_emb_pred, pitch_pred = self.pitch_adaptor.get_pitch_embedding( + x=encoder_outputs, + mask=src_mask, + pitch_transform=pitch_transform, + pitch_mean=self.pitch_mean if hasattr(self, "pitch_mean") else None, + pitch_std=self.pitch_std if hasattr(self, "pitch_std") else None, + ) + + energy_emb_pred, energy_pred = self.energy_adaptor.get_energy_embedding( + x=encoder_outputs, mask=src_mask, energy_transform=energy_transform + ) + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb_pred + energy_emb_pred + + log_duration_pred = self.duration_predictor( + x=encoder_outputs_res.detach(), mask=src_mask + ) # [B, C_hidden, T_src] -> [B, T_src] + duration_pred = (torch.exp(log_duration_pred) - 1) * (~src_mask) * self.length_scale # -> [B, T_src] + duration_pred[duration_pred < 1] = 1.0 # -> [B, T_src] + duration_pred = torch.round(duration_pred) # -> [B, T_src] + mel_lens = duration_pred.sum(1) # -> [B,] + + _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] + ) + + mel_mask = get_mask_from_lengths( + torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) + ) + + if encoder_outputs_ex.shape[1] > pos_encoding.shape[1]: + encoding = positional_encoding(self.emb_dim, encoder_outputs_ex.shape[2], device=tokens.device) + + # [B, C_hidden, T_src], [B, 1, T_src], [B, C_emb], [B, T_src, C_hidden] -> [B, C_hidden, T_src] + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + x = self.to_mel(x) + outputs = { + "model_outputs": x, + "alignments": alignments, + # "pitch": pitch_emb_pred, + "durations": duration_pred, + "pitch": pitch_pred, + "energy": energy_pred, + "spk_emb": speaker_embedding, + } + return outputs diff --git a/TTS/tts/layers/delightful_tts/conformer.py b/TTS/tts/layers/delightful_tts/conformer.py new file mode 100644 index 0000000000..b2175b3b96 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/conformer.py @@ -0,0 +1,450 @@ +### credit: https://github.com/dunky11/voicesmith +import math +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d +from TTS.tts.layers.delightful_tts.networks import GLUActivation + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class Conformer(nn.Module): + def __init__( + self, + dim: int, + n_layers: int, + n_heads: int, + speaker_embedding_dim: int, + p_dropout: float, + kernel_size_conv_mod: int, + lrelu_slope: float, + ): + """ + A Transformer variant that integrates both CNNs and Transformers components. + Conformer proposes a novel combination of self-attention and convolution, in which self-attention + learns the global interaction while the convolutions efficiently capture the local correlations. + + Args: + dim (int): Number of the dimensions for the model. + n_layers (int): Number of model layers. + n_heads (int): The number of attention heads. + speaker_embedding_dim (int): Number of speaker embedding dimensions. + p_dropout (float): Probabilty of dropout. 
+ kernel_size_conv_mod (int): Size of kernels for convolution modules. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by Conformer Encoder. + """ + super().__init__() + d_k = d_v = dim // n_heads + self.layer_stack = nn.ModuleList( + [ + ConformerBlock( + dim, + n_heads, + d_k, + d_v, + kernel_size_conv_mod=kernel_size_conv_mod, + dropout=p_dropout, + speaker_embedding_dim=speaker_embedding_dim, + lrelu_slope=lrelu_slope, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + speaker_embedding: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - speaker_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + + attn_mask = mask.view((mask.shape[0], 1, 1, mask.shape[1])) + for enc_layer in self.layer_stack: + x = enc_layer( + x, + mask=mask, + slf_attn_mask=attn_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + return x + + +class ConformerBlock(torch.nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + d_k: int, # pylint: disable=unused-argument + d_v: int, # pylint: disable=unused-argument + kernel_size_conv_mod: int, + speaker_embedding_dim: int, + dropout: float, + lrelu_slope: float = 0.3, + ): + """ + A Conformer block is composed of four modules stacked together, + A feed-forward module, a self-attention module, a convolution module, + and a second feed-forward module in the end. The block starts with two Feed forward + modules sandwiching the Multi-Headed Self-Attention module and the Conv module. + + Args: + d_model (int): The dimension of model + n_head (int): The number of attention heads. + kernel_size_conv_mod (int): Size of kernels for convolution modules. + speaker_embedding_dim (int): Number of speaker embedding dimensions. + emotion_embedding_dim (int): Number of emotion embedding dimensions. + dropout (float): Probabilty of dropout. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **slf_attn_mask** (batch, 1, 1, time1): Tensor containing indices to be masked in self attention module + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by the Conformer Block. 
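        Example (editorial sketch, not part of this diff; assumes ``positional_encoding`` from
        ``TTS.tts.layers.delightful_tts.networks`` returns a ``[1, T, C]`` table, as it is used by
        ``AcousticModel``; all sizes are illustrative):

            import torch
            from TTS.tts.layers.delightful_tts.networks import positional_encoding

            block = ConformerBlock(
                d_model=512, n_head=8, d_k=64, d_v=64,
                kernel_size_conv_mod=7, speaker_embedding_dim=256, dropout=0.1,
            )
            x = torch.randn(2, 60, 512)                      # [B, T_src, C]
            mask = torch.zeros(2, 60, dtype=torch.bool)      # True marks padded positions
            slf_attn_mask = mask.view(2, 1, 1, 60)           # attention mask, [B, 1, 1, T_src]
            spk = torch.randn(2, 256)                        # speaker embedding, [B, C_spk]
            encoding = positional_encoding(512, 60, device=x.device)
            y = block(x, speaker_embedding=spk, mask=mask, slf_attn_mask=slf_attn_mask, encoding=encoding)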
+ """ + super().__init__() + if isinstance(speaker_embedding_dim, int): + self.conditioning = Conv1dGLU( + d_model=d_model, + kernel_size=kernel_size_conv_mod, + padding=kernel_size_conv_mod // 2, + embedding_dim=speaker_embedding_dim, + ) + + self.ff = FeedForward(d_model=d_model, dropout=dropout, kernel_size=3, lrelu_slope=lrelu_slope) + self.conformer_conv_1 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + self.ln = nn.LayerNorm(d_model) + self.slf_attn = ConformerMultiHeadedSelfAttention(d_model=d_model, num_heads=n_head, dropout_p=dropout) + self.conformer_conv_2 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + + def forward( + self, + x: torch.Tensor, + speaker_embedding: torch.Tensor, + mask: torch.Tensor, + slf_attn_mask: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - slf_attn_mask: :math: `[B, 1, 1, T_src]` + - speaker_embedding: :math: `[B, C]` + - emotion_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + if speaker_embedding is not None: + x = self.conditioning(x, embeddings=speaker_embedding) + x = self.ff(x) + x + x = self.conformer_conv_1(x) + x + res = x + x = self.ln(x) + x, _ = self.slf_attn(query=x, key=x, value=x, mask=slf_attn_mask, encoding=encoding) + x = x + res + x = x.masked_fill(mask.unsqueeze(-1), 0) + + x = self.conformer_conv_2(x) + x + return x + + +class FeedForward(nn.Module): + def __init__( + self, + d_model: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + expansion_factor: int = 4, + ): + """ + Feed Forward module for conformer block. + + Args: + d_model (int): The dimension of model. + kernel_size (int): Size of the kernels for conv layers. + dropout (float): probability of dropout. + expansion_factor (int): The factor by which to project the number of channels. + lrelu_slope (int): the negative slope factor for the leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the feed forward module. + """ + super().__init__() + self.dropout = nn.Dropout(dropout) + self.ln = nn.LayerNorm(d_model) + self.conv_1 = nn.Conv1d( + d_model, + d_model * expansion_factor, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.conv_1(x) + x = x.permute((0, 2, 1)) + x = self.act(x) + x = self.dropout(x) + x = x.permute((0, 2, 1)) + x = self.conv_2(x) + x = x.permute((0, 2, 1)) + x = self.dropout(x) + x = 0.5 * x + return x + + +class ConformerConvModule(nn.Module): + def __init__( + self, + d_model: int, + expansion_factor: int = 2, + kernel_size: int = 7, + dropout: float = 0.1, + lrelu_slope: float = 0.3, + ): + """ + Convolution module for conformer. Starts with a gating machanism. + a pointwise convolution and a gated linear unit (GLU). This is followed + by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution + to help with training. it also contains an expansion factor to project the number of channels. + + Args: + d_model (int): The dimension of model. 
+ expansion_factor (int): The factor by which to project the number of channels. + kernel_size (int): Size of kernels for convolution modules. + dropout (float): Probabilty of dropout. + lrelu_slope (float): The slope coefficient for leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the conv module. + + """ + super().__init__() + inner_dim = d_model * expansion_factor + self.ln_1 = nn.LayerNorm(d_model) + self.conv_1 = PointwiseConv1d(d_model, inner_dim * 2) + self.conv_act = GLUActivation(slope=lrelu_slope) + self.depthwise = DepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=kernel_size, + padding=calc_same_padding(kernel_size)[0], + ) + self.ln_2 = nn.GroupNorm(1, inner_dim) + self.activation = nn.LeakyReLU(lrelu_slope) + self.conv_2 = PointwiseConv1d(inner_dim, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln_1(x) + x = x.permute(0, 2, 1) + x = self.conv_1(x) + x = self.conv_act(x) + x = self.depthwise(x) + x = self.ln_2(x) + x = self.activation(x) + x = self.conv_2(x) + x = x.permute(0, 2, 1) + x = self.dropout(x) + return x + + +class ConformerMultiHeadedSelfAttention(nn.Module): + """ + Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, + the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention + module to generalize better on different input length and the resulting encoder is more robust to the variance of + the utterance length. Conformer use prenorm residual units with dropout which helps training + and regularizing deeper models. + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. + dropout_p (float): probability of dropout + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. + """ + + def __init__(self, d_model: int, num_heads: int, dropout_p: float): + super().__init__() + self.attention = RelativeMultiHeadAttention(d_model=d_model, num_heads=num_heads) + self.dropout = nn.Dropout(p=dropout_p) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + encoding: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_length, _ = key.size() # pylint: disable=unused-variable + encoding = encoding[:, : key.shape[1]] + encoding = encoding.repeat(batch_size, 1, 1) + outputs, attn = self.attention(query, key, value, pos_embedding=encoding, mask=mask) + outputs = self.dropout(outputs) + return outputs, attn + + +class RelativeMultiHeadAttention(nn.Module): + """ + Multi-head attention with relative positional encoding. + This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. 
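    Example (editorial sketch, not part of this diff; shapes are illustrative):

        import torch

        mha = RelativeMultiHeadAttention(d_model=512, num_heads=16)
        q = k = v = torch.randn(2, 60, 512)                    # [B, T, C]
        pos = torch.randn(2, 60, 512)                          # relative positional embeddings
        mask = torch.zeros(2, 1, 1, 60, dtype=torch.bool)      # True marks masked positions
        out, attn = mha(q, k, v, pos_embedding=pos, mask=mask)  # out: [B, T, C]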
+ Inputs: query, key, value, pos_embedding, mask + - **query** (batch, time, dim): Tensor containing query vector + - **key** (batch, time, dim): Tensor containing key vector + - **value** (batch, time, dim): Tensor containing value vector + - **pos_embedding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs**: Tensor produces by relative multi head attention module. + """ + + def __init__( + self, + d_model: int = 512, + num_heads: int = 16, + ): + super().__init__() + assert d_model % num_heads == 0, "d_model % num_heads should be zero." + self.d_model = d_model + self.d_head = int(d_model / num_heads) + self.num_heads = num_heads + self.sqrt_dim = math.sqrt(d_model) + + self.query_proj = nn.Linear(d_model, d_model) + self.key_proj = nn.Linear(d_model, d_model, bias=False) + self.value_proj = nn.Linear(d_model, d_model, bias=False) + self.pos_proj = nn.Linear(d_model, d_model, bias=False) + + self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + torch.nn.init.xavier_uniform_(self.u_bias) + torch.nn.init.xavier_uniform_(self.v_bias) + self.out_proj = nn.Linear(d_model, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_embedding: torch.Tensor, + mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = query.shape[0] + query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) + key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) + u_bias = self.u_bias.expand_as(query) + v_bias = self.v_bias.expand_as(query) + a = (query + u_bias).transpose(1, 2) + content_score = a @ key.transpose(2, 3) + b = (query + v_bias).transpose(1, 2) + pos_score = b @ pos_embedding.permute(0, 2, 3, 1) + pos_score = self._relative_shift(pos_score) + + score = content_score + pos_score + score = score * (1.0 / self.sqrt_dim) + + score.masked_fill_(mask, -1e9) + + attn = F.softmax(score, -1) + + context = (attn @ value).transpose(1, 2) + context = context.contiguous().view(batch_size, -1, self.d_model) + + return self.out_proj(context), attn + + def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use + batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() + zeros = torch.zeros((batch_size, num_heads, seq_length1, 1), device=pos_score.device) + padded_pos_score = torch.cat([zeros, pos_score], dim=-1) + padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) + pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) + return pos_score + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = 
nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor: + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + split_size = self.num_units // self.num_heads + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + # score = softmax(QK^T / (d_k ** 0.5)) + scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + return out diff --git a/TTS/tts/layers/delightful_tts/conv_layers.py b/TTS/tts/layers/delightful_tts/conv_layers.py new file mode 100644 index 0000000000..354a0336a1 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/conv_layers.py @@ -0,0 +1,670 @@ +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class ConvNorm(nn.Module): + """A 1-dimensional convolutional layer with optional weight normalization. + + This layer wraps a 1D convolutional layer from PyTorch and applies + optional weight normalization. The layer can be used in a similar way to + the convolutional layers in PyTorch's `torch.nn` module. + + Args: + in_channels (int): The number of channels in the input signal. + out_channels (int): The number of channels in the output signal. + kernel_size (int, optional): The size of the convolving kernel. + Defaults to 1. + stride (int, optional): The stride of the convolution. Defaults to 1. + padding (int, optional): Zero-padding added to both sides of the input. + If `None`, the padding will be calculated so that the output has + the same length as the input. Defaults to `None`. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`. + w_init_gain (str, optional): The weight initialization function to use. + Can be either 'linear' or 'relu'. Defaults to 'linear'. + use_weight_norm (bool, optional): If `True`, apply weight normalization + to the convolutional weights. Defaults to `False`. + + Shapes: + - Input: :math:`[N, D, T]` + + - Output: :math:`[N, out_dim, T]` where `out_dim` is the number of output dimensions. 
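    Example (editorial sketch, not part of this diff):

        import torch

        conv = ConvNorm(in_channels=80, out_channels=256, kernel_size=5, w_init_gain="relu")
        x = torch.randn(4, 80, 120)       # [N, D, T] mel-like input
        mask = torch.ones(4, 1, 120)      # optional mask re-zeroes padded frames
        y = conv(x, mask=mask)            # [N, 256, 120]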
+ + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + use_weight_norm=False, + ): + super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + self.kernel_size = kernel_size + self.dilation = dilation + self.use_weight_norm = use_weight_norm + conv_fn = nn.Conv1d + self.conv = conv_fn( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) + if self.use_weight_norm: + self.conv = nn.utils.weight_norm(self.conv) + + def forward(self, signal, mask=None): + conv_signal = self.conv(signal) + if mask is not None: + # always re-zero output if mask is + # available to match zero-padding + conv_signal = conv_signal * mask + return conv_signal + + +class ConvLSTMLinear(nn.Module): + def __init__( + self, + in_dim, + out_dim, + n_layers=2, + n_channels=256, + kernel_size=3, + p_dropout=0.1, + lstm_type="bilstm", + use_linear=True, + ): + super(ConvLSTMLinear, self).__init__() # pylint: disable=super-with-arguments + self.out_dim = out_dim + self.lstm_type = lstm_type + self.use_linear = use_linear + self.dropout = nn.Dropout(p=p_dropout) + + convolutions = [] + for i in range(n_layers): + conv_layer = ConvNorm( + in_dim if i == 0 else n_channels, + n_channels, + kernel_size=kernel_size, + stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ) + conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight") + convolutions.append(conv_layer) + + self.convolutions = nn.ModuleList(convolutions) + + if not self.use_linear: + n_channels = out_dim + + if self.lstm_type != "": + use_bilstm = False + lstm_channels = n_channels + if self.lstm_type == "bilstm": + use_bilstm = True + lstm_channels = int(n_channels // 2) + + self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm) + lstm_norm_fn_pntr = nn.utils.spectral_norm + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0") + if self.lstm_type == "bilstm": + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse") + + if self.use_linear: + self.dense = nn.Linear(n_channels, out_dim) + + def run_padded_sequence(self, context, lens): + context_embedded = [] + for b_ind in range(context.size()[0]): # TODO: speed up + curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() + for conv in self.convolutions: + curr_context = self.dropout(F.relu(conv(curr_context))) + context_embedded.append(curr_context[0].transpose(0, 1)) + context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) + return context + + def run_unsorted_inputs(self, fn, context, lens): # pylint: disable=no-self-use + lens_sorted, ids_sorted = torch.sort(lens, descending=True) + unsort_ids = [0] * lens.size(0) + for i in range(len(ids_sorted)): # pylint: disable=consider-using-enumerate + unsort_ids[ids_sorted[i]] = i + lens_sorted = lens_sorted.long().cpu() + + context = context[ids_sorted] + context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True) + context = fn(context)[0] + context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0] + + # map back to original indices + context = context[unsort_ids] + return context + + def forward(self, context, lens): + 
if context.size()[0] > 1: + context = self.run_padded_sequence(context, lens) + # to B, D, T + context = context.transpose(1, 2) + else: + for conv in self.convolutions: + context = self.dropout(F.relu(conv(context))) + + if self.lstm_type != "": + context = context.transpose(1, 2) + self.bilstm.flatten_parameters() + if lens is not None: + context = self.run_unsorted_inputs(self.bilstm, context, lens) + else: + context = self.bilstm(context)[0] + context = context.transpose(1, 2) + + x_hat = context + if self.use_linear: + x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2) + + return x_hat + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int): + super().__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class PointwiseConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class BSConv1d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv1d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class BSConv2d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv2d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class Conv1dGLU(nn.Module): + """From DeepVoice 3""" + + def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int): + super().__init__() + self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding) + self.embedding_proj = nn.Linear(embedding_dim, d_model) + self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0)) + self.softsign = torch.nn.Softsign() + + def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor: + x = x.permute((0, 2, 1)) + residual = x + x = self.conv(x) + splitdim = 1 + a, b = x.split(x.size(splitdim) // 2, dim=splitdim) + embeddings = self.embedding_proj(embeddings).unsqueeze(2) + softsign = self.softsign(embeddings) + softsign = softsign.expand_as(a) + a = a + softsign + x = a * torch.sigmoid(b) + x = x + residual + x = x * self.sqrt + x = x.permute((0, 2, 1)) + return x + + +class ConvTransposed(nn.Module): + """ + A 1D convolutional transposed layer for PyTorch. 
+ This layer applies a 1D convolutional transpose operation to its input tensor, + where the number of channels of the input tensor is the same as the number of channels of the output tensor. + + Attributes: + in_channels (int): The number of channels in the input tensor. + out_channels (int): The number of channels in the output tensor. + kernel_size (int): The size of the convolutional kernel. Default: 1. + padding (int): The number of padding elements to add to the input tensor. Default: 0. + conv (BSConv1d): The 1D convolutional transpose layer. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + padding: int = 0, + ): + super().__init__() + self.conv = BSConv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous().transpose(1, 2) + x = self.conv(x) + x = x.contiguous().transpose(1, 2) + return x + + +class DepthwiseConvModule(nn.Module): + def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3): + super().__init__() + padding = calc_same_padding(kernel_size) + self.depthwise = nn.Conv1d( + dim, + dim * expansion, + kernel_size=kernel_size, + padding=padding[0], + groups=dim, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0) + self.ln = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.depthwise(x) + x = self.act(x) + x = self.out(x) + x = x.permute((0, 2, 1)) + return x + + +class AddCoords(nn.Module): + def __init__(self, rank: int, with_r: bool = False): + super().__init__() + self.rank = rank + self.with_r = with_r + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.rank == 1: + batch_size_shape, channel_in_shape, dim_x = x.shape # pylint: disable=unused-variable + xx_range = torch.arange(dim_x, dtype=torch.int32) + xx_channel = xx_range[None, None, :] + + xx_channel = xx_channel.float() / (dim_x - 1) + xx_channel = xx_channel * 2 - 1 + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) + + xx_channel = xx_channel.to(x.device) + out = torch.cat([x, xx_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 2: + batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) + + xx_range = torch.arange(dim_y, dtype=torch.int32) + yy_range = torch.arange(dim_x, dtype=torch.int32) + xx_range = xx_range[None, None, :, None] + yy_range = yy_range[None, None, :, None] + + xx_channel = torch.matmul(xx_range, xx_ones) + yy_channel = torch.matmul(yy_range, yy_ones) + + # transpose y + yy_channel = yy_channel.permute(0, 1, 3, 2) + + xx_channel = xx_channel.float() / (dim_y - 1) + yy_channel = yy_channel.float() / (dim_x - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) + yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + + out = torch.cat([x, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 3: + batch_size_shape, channel_in_shape, dim_z, dim_y, 
dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) + zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) + + xy_range = torch.arange(dim_y, dtype=torch.int32) + xy_range = xy_range[None, None, None, :, None] + + yz_range = torch.arange(dim_z, dtype=torch.int32) + yz_range = yz_range[None, None, None, :, None] + + zx_range = torch.arange(dim_x, dtype=torch.int32) + zx_range = zx_range[None, None, None, :, None] + + xy_channel = torch.matmul(xy_range, xx_ones) + xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) + + yz_channel = torch.matmul(yz_range, yy_ones) + yz_channel = yz_channel.permute(0, 1, 3, 4, 2) + yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) + + zx_channel = torch.matmul(zx_range, zz_ones) + zx_channel = zx_channel.permute(0, 1, 4, 2, 3) + zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + zz_channel = zz_channel.to(x.device) + out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1) + + if self.with_r: + rr = torch.sqrt( + torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2) + ) + out = torch.cat([out, rr], dim=1) + else: + raise NotImplementedError + + return out + + +class CoordConv1d(nn.modules.conv.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 1 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv1d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class CoordConv2d(nn.modules.conv.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 2 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv2d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + 
kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm( + nn.Conv1d( + in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
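# NOTE (illustrative sketch, not part of this patch): the einsum used by
# location_variable_convolution below contracts the per-position kernels predicted by
# KernelPredictor against unfolded input windows. Labels follow the shape comments in
# the code: b=batch, i=in_channels, o=out_channels, l=kernel_length, d=dilation,
# s=positions within a hop, k=kernel_size. All sizes here are made-up example values.
import torch

b, i, o, l, d, s, k = 2, 8, 16, 5, 1, 4, 3
x_unfolded = torch.randn(b, i, l, d, s, k)   # unfolded input windows
kernels = torch.randn(b, i, o, k, l)         # predicted local kernels
out = torch.einsum("bildsk,biokl->bolsd", x_unfolded, kernels)
print(out.shape)                             # torch.Size([2, 16, 5, 4, 1])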
+ """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + nn.utils.remove_weight_norm(self.convt_pre[1]) + for block in self.conv_blocks: + nn.utils.remove_weight_norm(block[1]) diff --git a/TTS/tts/layers/delightful_tts/encoders.py b/TTS/tts/layers/delightful_tts/encoders.py new file mode 100644 index 0000000000..0878f0677a --- /dev/null +++ b/TTS/tts/layers/delightful_tts/encoders.py @@ -0,0 +1,261 @@ +from typing import List, Tuple, Union + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention +from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d +from TTS.tts.layers.delightful_tts.networks import STL + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +class ReferenceEncoder(nn.Module): + """ + Referance encoder for utterance and phoneme prosody encoders. Reference encoder + made up of convolution and RNN layers. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for encoder layers. + ref_enc_size (int): Size of the kernel for the conv layers. + ref_enc_strides (List[int]): List of strides to use for conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. + Returns: + - **outputs** (batch, time, dim): Tensor produced by Reference Encoder. 
+ """ + + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + ): + super().__init__() + + n_mel_channels = num_mels + self.n_mel_channels = n_mel_channels + K = len(ref_enc_filters) + filters = [self.n_mel_channels] + ref_enc_filters + strides = [1] + ref_enc_strides + # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf + convs = [ + CoordConv1d( + in_channels=filters[0], + out_channels=filters[0 + 1], + kernel_size=ref_enc_size, + stride=strides[0], + padding=ref_enc_size // 2, + with_r=True, + ) + ] + convs2 = [ + nn.Conv1d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=ref_enc_size, + stride=strides[i], + padding=ref_enc_size // 2, + ) + for i in range(1, K) + ] + convs.extend(convs2) + self.convs = nn.ModuleList(convs) + + self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]) + + self.gru = nn.GRU( + input_size=ref_enc_filters[-1], + hidden_size=ref_enc_gru_size, + batch_first=True, + ) + + def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + inputs --- [N, n_mels, timesteps] + outputs --- [N, E//2] + """ + + mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1) + x = x.masked_fill(mel_masks, 0) + for conv, norm in zip(self.convs, self.norms): + x = conv(x) + x = F.leaky_relu(x, 0.3) # [N, 128, Ty//2^K, n_mels//2^K] + x = norm(x) + + for _ in range(2): + mel_lens = stride_lens(mel_lens) + + mel_masks = get_mask_from_lengths(mel_lens) + + x = x.masked_fill(mel_masks.unsqueeze(1), 0) + x = x.permute((0, 2, 1)) + x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False) + + self.gru.flatten_parameters() + x, memory = self.gru(x) # memory --- [N, Ty, E//2], out --- [1, N, E//2] + x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + + return x, memory, mel_masks + + def calculate_channels( # pylint: disable=no-self-use + self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int + ) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class UtteranceLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + bottleneck_size_u: int, + token_num: int, + ): + """ + Encoder to extract prosody from utterance. it is made up of a reference encoder + with a couple of linear layers and style token layer with dropout. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for ref encoder layers. + ref_enc_size (int): Size of the kernel for the ref encoder conv layers. + ref_enc_strides (List[int]): List of strides to use for teh ref encoder conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + dropout (float): Probability of dropout. + n_hidden (int): Size of hidden layers. + bottleneck_size_u (int): Size of the bottle neck layer. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. 
+ Returns: + - **outputs** (batch, 1, dim): Tensor produced by Utterance Level Prosody Encoder. + """ + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_u + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2) + self.stl = STL(n_hidden=n_hidden, token_num=token_num) + self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor: + """ + Shapes: + mels: :math: `[B, C, T]` + mel_lens: :math: `[B]` + + out --- [N, seq_len, E] + """ + _, embedded_prosody, _ = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + # Style Token + out = self.encoder_bottleneck(self.stl(embedded_prosody)) + out = self.dropout(out) + + out = out.view((-1, 1, out.shape[3])) + return out + + +class PhonemeLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + n_heads: int, + bottleneck_size_p: int, + ): + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_p + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden) + self.attention = ConformerMultiHeadedSelfAttention( + d_model=n_hidden, + num_heads=n_heads, + dropout_p=dropout, + ) + self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size) + + def forward( + self, + x: torch.Tensor, + src_mask: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + x --- [N, seq_len, encoder_embedding_dim] + mels --- [N, Ty/r, n_mels*r], r=1 + out --- [N, seq_len, bottleneck_size] + attn --- [N, seq_len, ref_len], Ty/r = ref_len + """ + embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1)) + x, _ = self.attention( + query=x, + key=embedded_prosody, + value=embedded_prosody, + mask=attn_mask, + encoding=encoding, + ) + x = self.encoder_bottleneck(x) + x = x.masked_fill(src_mask.unsqueeze(-1), 0.0) + return x diff --git a/TTS/tts/layers/delightful_tts/energy_adaptor.py b/TTS/tts/layers/delightful_tts/energy_adaptor.py new file mode 100644 index 0000000000..ea0d1e4721 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/energy_adaptor.py @@ -0,0 +1,82 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class EnergyAdaptor(nn.Module): # pylint: disable=abstract-method + """Variance Adaptor with an added 1D conv layer. Used to + get energy embeddings. + + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels. 
+ kernel_size (int): Size the kernel for the conv layers. + dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + emb_kernel_size (int): Size the kernel for the pitch embedding. + + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the energy target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **energy prediction** (batch, 1, time1): Tensor produced by energy predictor + - **energy embedding** (batch, channels, time1): Tensor produced energy adaptor + - **average energy target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + + """ + + def __init__( + self, + channels_in: int, + channels_hidden: int, + channels_out: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + emb_kernel_size: int, + ): + super().__init__() + self.energy_predictor = VariancePredictor( + channels_in=channels_in, + channels=channels_hidden, + channels_out=channels_out, + kernel_size=kernel_size, + p_dropout=dropout, + lrelu_slope=lrelu_slope, + ) + self.energy_emb = nn.Conv1d( + 1, + channels_hidden, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_energy_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + avg_energy_target = average_over_durations(target, dr) + energy_emb = self.energy_emb(avg_energy_target) + return energy_pred, avg_energy_target, energy_emb + + def get_energy_embedding(self, x: torch.Tensor, mask: torch.Tensor, energy_transform: Callable) -> torch.Tensor: + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + if energy_transform is not None: + energy_pred = energy_transform(energy_pred, (~mask).sum(dim=(1, 2)), self.pitch_mean, self.pitch_std) + energy_emb_pred = self.energy_emb(energy_pred) + return energy_emb_pred, energy_pred diff --git a/TTS/tts/layers/delightful_tts/kernel_predictor.py b/TTS/tts/layers/delightful_tts/kernel_predictor.py new file mode 100644 index 0000000000..19dfd57e7b --- /dev/null +++ b/TTS/tts/layers/delightful_tts/kernel_predictor.py @@ -0,0 +1,125 @@ +import torch.nn as nn # pylint: disable=consider-using-from-import + + +class KernelPredictor(nn.Module): + """Kernel predictor for the location-variable convolutions + + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + + """ + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + 
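# NOTE (illustrative, not part of this patch): with the values LVCBlock passes in above
# (conv_out_channels = 2 * in_channels, conv_kernel_size = 3, conv_layers = len(dilations) = 4),
# the flat channel counts computed next are later reshaped in forward() into per-layer
# kernels and biases. c_g below is an arbitrary example width, not a config default.
c_g, kernel_size, layers = 32, 3, 4
kpnet_kernel_channels = c_g * (2 * c_g) * kernel_size * layers  # viewed as (B, layers, c_g, 2*c_g, k, T_cond)
kpnet_bias_channels = (2 * c_g) * layers                        # viewed as (B, layers, 2*c_g, T_cond)
print(kpnet_kernel_channels, kpnet_bias_channels)               # 24576 256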
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv[0]) + nn.utils.remove_weight_norm(self.kernel_conv) + nn.utils.remove_weight_norm(self.bias_conv) + for block in self.residual_convs: + nn.utils.remove_weight_norm(block[1]) + nn.utils.remove_weight_norm(block[3]) diff --git a/TTS/tts/layers/delightful_tts/networks.py b/TTS/tts/layers/delightful_tts/networks.py new file mode 100644 index 0000000000..4305022f18 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/networks.py @@ -0,0 +1,219 @@ +import math +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import ConvNorm + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." + # Kaiming initialization + return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +def positional_encoding(d_model: int, length: int, device: torch.device) -> torch.Tensor: + pe = torch.zeros(length, d_model, device=device) + position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + return pe + + +class BottleneckLayer(nn.Module): + """ + Bottleneck layer for reducing the dimensionality of a tensor. 
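# NOTE (illustrative usage sketch, not part of this patch): the positional_encoding
# helper defined above returns a sinusoid table with a leading batch dimension.
# The import path is the networks.py file added by this patch; d_model and length
# are example values.
import torch
from TTS.tts.layers.delightful_tts.networks import positional_encoding

pe = positional_encoding(d_model=512, length=100, device=torch.device("cpu"))
print(pe.shape)  # torch.Size([1, 100, 512]); even columns are sine, odd columns cosine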
+ + Args: + in_dim: The number of input dimensions. + reduction_factor: The factor by which to reduce the number of dimensions. + norm: The normalization method to use. Can be "weightnorm" or "instancenorm". + non_linearity: The non-linearity to use. Can be "relu" or "leakyrelu". + kernel_size: The size of the convolutional kernel. + use_partial_padding: Whether to use partial padding with the convolutional kernel. + + Shape: + - Input: :math:`[N, in_dim]` where `N` is the batch size and `in_dim` is the number of input dimensions. + + - Output: :math:`[N, out_dim]` where `out_dim` is the number of output dimensions. + """ + + def __init__( + self, + in_dim, + reduction_factor, + norm="weightnorm", + non_linearity="relu", + kernel_size=3, + use_partial_padding=False, # pylint: disable=unused-argument + ): + super(BottleneckLayer, self).__init__() # pylint: disable=super-with-arguments + + self.reduction_factor = reduction_factor + reduced_dim = int(in_dim / reduction_factor) + self.out_dim = reduced_dim + if self.reduction_factor > 1: + fn = ConvNorm(in_dim, reduced_dim, kernel_size=kernel_size, use_weight_norm=(norm == "weightnorm")) + if norm == "instancenorm": + fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True)) + + self.projection_fn = fn + self.non_linearity = nn.ReLU() + if non_linearity == "leakyrelu": + self.non_linearity = nn.LeakyReLU() + + def forward(self, x): + if self.reduction_factor > 1: + x = self.projection_fn(x) + x = self.non_linearity(x) + return x + + +class GLUActivation(nn.Module): + """Class that implements the Gated Linear Unit (GLU) activation function. + + The input is split in half along the channel dimension; one half is passed + through a leaky ReLU and used to gate the other half.
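# NOTE (illustrative sketch, not part of this patch): the gating performed by
# GLUActivation below, assuming an input with 2*C channels; C, T and the slope
# are example values.
import torch
import torch.nn as nn

x = torch.randn(2, 8, 10)          # [B, 2*C, T]
out, gate = x.chunk(2, dim=1)      # split into a content half and a gate half, each [B, C, T]
y = out * nn.LeakyReLU(0.3)(gate)  # content modulated by the leaky-ReLU gate
print(y.shape)                     # torch.Size([2, 4, 10])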
+ + """ + + def __init__(self, slope: float): + super().__init__() + self.lrelu = nn.LeakyReLU(slope) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, gate = x.chunk(2, dim=1) + x = out * self.lrelu(gate) + return x + + +class StyleEmbedAttention(nn.Module): + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor: + values = self.W_value(key_soft) + split_size = self.num_units // self.num_heads + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) + + out_soft = scores_soft = None + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key_soft) # [N, T_k, num_units] + + # [h, N, T_q, num_units/h] + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k ** 0.5)) + scores_soft = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores_soft = scores_soft / (self.key_dim**0.5) + scores_soft = F.softmax(scores_soft, dim=3) + + # out = score * V + # [h, N, T_q, num_units/h] + out_soft = torch.matmul(scores_soft, values) + out_soft = torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out_soft # , scores_soft + + +class EmbeddingPadded(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + super().__init__() + padding_mult = torch.ones((num_embeddings, 1), dtype=torch.int64) + padding_mult[padding_idx] = 0 + self.register_buffer("padding_mult", padding_mult) + self.embeddings = nn.parameter.Parameter(initialize_embeddings((num_embeddings, embedding_dim))) + + def forward(self, idx: torch.Tensor) -> torch.Tensor: + embeddings_zeroed = self.embeddings * self.padding_mult + x = F.embedding(idx, embeddings_zeroed) + return x + + +class EmbeddingProjBlock(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = x + for layer in self.layers: + x = layer(x) + x = x + res + return x + + +class LinearNorm(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + + nn.init.xavier_uniform_(self.linear.weight) + if bias: + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear(x) + return x + + +class STL(nn.Module): + """ + A PyTorch module for the Style Token Layer (STL) as described in + "A Style-Based Generator Architecture for Generative Adversarial Networks" + (https://arxiv.org/abs/1812.04948) + + The STL applies a multi-headed attention mechanism over the learned style tokens, + using the text input as the query and the style tokens as the keys and values. 
+ The output of the attention mechanism is used as the text's style embedding. + + Args: + token_num (int): The number of style tokens. + n_hidden (int): Number of hidden dimensions. + """ + + def __init__(self, n_hidden: int, token_num: int): + super(STL, self).__init__() # pylint: disable=super-with-arguments + + num_heads = 1 + E = n_hidden + self.token_num = token_num + self.embed = nn.Parameter(torch.FloatTensor(self.token_num, E // num_heads)) + d_q = E // 2 + d_k = E // num_heads + self.attention = StyleEmbedAttention(query_dim=d_q, key_dim=d_k, num_units=E, num_heads=num_heads) + + torch.nn.init.normal_(self.embed, mean=0, std=0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + N = x.size(0) + query = x.unsqueeze(1) # [N, 1, E//2] + + keys_soft = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads] + + # Weighted sum + emotion_embed_soft = self.attention(query, keys_soft) + + return emotion_embed_soft diff --git a/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py b/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py new file mode 100644 index 0000000000..28418f7163 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class PhonemeProsodyPredictor(nn.Module): + """Non-parallel Prosody Predictor inspired by: https://arxiv.org/pdf/2102.00851.pdf + It consists of 2 layers of 1D convolutions each followed by a relu activation, layer norm + and dropout, then finally a linear layer. + + Args: + hidden_size (int): Size of hidden channels. + kernel_size (int): Kernel size for the conv layers. + dropout: (float): Probability of dropout. + bottleneck_size (int): bottleneck size for last linear layer. + lrelu_slope (float): Slope of the leaky relu. 
+ """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + dropout: float, + bottleneck_size: int, + lrelu_slope: float, + ): + super().__init__() + self.d_model = hidden_size + self.layers = nn.ModuleList( + [ + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ] + ) + self.predictor_bottleneck = nn.Linear(self.d_model, bottleneck_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, D]` + mask: :math: `[B, T]` + """ + mask = mask.unsqueeze(2) + for layer in self.layers: + x = layer(x) + x = x.masked_fill(mask, 0.0) + x = self.predictor_bottleneck(x) + return x diff --git a/TTS/tts/layers/delightful_tts/pitch_adaptor.py b/TTS/tts/layers/delightful_tts/pitch_adaptor.py new file mode 100644 index 0000000000..9031369e0f --- /dev/null +++ b/TTS/tts/layers/delightful_tts/pitch_adaptor.py @@ -0,0 +1,88 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class PitchAdaptor(nn.Module): # pylint: disable=abstract-method + """Module to get pitch embeddings via pitch predictor + + Args: + n_input (int): Number of pitch predictor input channels. + n_hidden (int): Number of pitch predictor hidden channels. + n_out (int): Number of pitch predictor out channels. + kernel size (int): Size of the kernel for conv layers. + emb_kernel_size (int): Size the kernel for the pitch embedding. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. 
+ + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the pitch target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **pitch prediction** (batch, 1, time1): Tensor produced by pitch predictor + - **pitch embedding** (batch, channels, time1): Tensor produced pitch pitch adaptor + - **average pitch target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + """ + + def __init__( + self, + n_input: int, + n_hidden: int, + n_out: int, + kernel_size: int, + emb_kernel_size: int, + p_dropout: float, + lrelu_slope: float, + ): + super().__init__() + self.pitch_predictor = VariancePredictor( + channels_in=n_input, + channels=n_hidden, + channels_out=n_out, + kernel_size=kernel_size, + p_dropout=p_dropout, + lrelu_slope=lrelu_slope, + ) + self.pitch_emb = nn.Conv1d( + 1, + n_input, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_pitch_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + pitch_pred = self.pitch_predictor(x, mask) # [B, T_src, C_hidden], [B, T_src] --> [B, T_src] + pitch_pred.unsqueeze_(1) # --> [B, 1, T_src] + avg_pitch_target = average_over_durations(target, dr) # [B, 1, T_mel], [B, T_src] --> [B, 1, T_src] + pitch_emb = self.pitch_emb(avg_pitch_target) # [B, 1, T_src] --> [B, C_hidden, T_src] + return pitch_pred, avg_pitch_target, pitch_emb + + def get_pitch_embedding( + self, + x: torch.Tensor, + mask: torch.Tensor, + pitch_transform: Callable, + pitch_mean: torch.Tensor, + pitch_std: torch.Tensor, + ) -> torch.Tensor: + pitch_pred = self.pitch_predictor(x, mask) + if pitch_transform is not None: + pitch_pred = pitch_transform(pitch_pred, (~mask).sum(), pitch_mean, pitch_std) + pitch_pred.unsqueeze_(1) + pitch_emb_pred = self.pitch_emb(pitch_pred) + return pitch_emb_pred, pitch_pred diff --git a/TTS/tts/layers/delightful_tts/variance_predictor.py b/TTS/tts/layers/delightful_tts/variance_predictor.py new file mode 100644 index 0000000000..68303a1bd1 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/variance_predictor.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class VariancePredictor(nn.Module): + """ + Network is 2-layer 1D convolutions with leaky relu activation and then + followed by layer normalization then a dropout layer and finally an + extra linear layer to project the hidden states into the output sequence. + + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels for the last linear layer. + kernel_size (int): Size the kernel for the conv layers. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, time): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time): Tensor produced by last linear layer. 
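# NOTE (illustrative usage sketch, not part of this patch): input and output shapes
# of VariancePredictor as defined below; channel sizes are example values and the
# import path is the file added by this patch.
import torch
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor

predictor = VariancePredictor(channels_in=512, channels=256, channels_out=1,
                              kernel_size=5, p_dropout=0.5, lrelu_slope=0.3)
x = torch.randn(2, 37, 512)                  # [B, T_src, C]
mask = torch.zeros(2, 37, dtype=torch.bool)  # True marks padded positions
y = predictor(x, mask)                       # [B, T_src]; padded positions zeroed
print(y.shape)                               # torch.Size([2, 37])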
+ """ + + def __init__( + self, channels_in: int, channels: int, channels_out: int, kernel_size: int, p_dropout: float, lrelu_slope: float + ): + super().__init__() + + self.layers = nn.ModuleList( + [ + ConvTransposed( + channels_in, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ConvTransposed( + channels, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ] + ) + + self.linear_layer = nn.Linear(channels, channels_out) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T_src, C]` + mask: :math: `[B, T_src]` + """ + for layer in self.layers: + x = layer(x) + x = self.linear_layer(x) + x = x.squeeze(-1) + x = x.masked_fill(mask, 0.0) + return x diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py index eef4c4b66d..baa6f0e9c4 100644 --- a/TTS/tts/layers/generic/aligner.py +++ b/TTS/tts/layers/generic/aligner.py @@ -57,6 +57,15 @@ def __init__( nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), ) + self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.key_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.key_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.query_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[4].weight, gain=torch.nn.init.calculate_gain("linear")) + def forward( self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None ) -> Tuple[torch.tensor, torch.tensor]: @@ -75,7 +84,9 @@ def forward( attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) if attn_prior is not None: attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + if mask is not None: attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) return attn, attn_logp diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py index 260e504a1a..ee3b820637 100644 --- a/TTS/tts/models/bark.py +++ b/TTS/tts/models/bark.py @@ -214,6 +214,7 @@ def synthesize( as latents used at inference. 
""" + speaker_id = "random" if speaker_id is None else speaker_id voice_dirs = self._set_voice_dirs(voice_dirs) history_prompt = load_voice(self, speaker_id, voice_dirs) outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index bda0fc9ea8..7871cc38c3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -439,3 +439,21 @@ def on_init_start(self, trainer): trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) print(f" > `language_ids.json` is saved to {output_path}.") print(" > `language_ids_file` is updated in the config.json.") + + +class BaseTTSE2E(BaseTTS): + def _set_model_args(self, config: Coqpit): + self.config = config + if "Config" in config.__class__.__name__: + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) + self.config.model_args.num_chars = num_chars + self.config.num_chars = num_chars + self.args = config.model_args + self.args.num_chars = num_chars + elif "Args" in config.__class__.__name__: + self.args = config + self.args.num_chars = self.args.num_chars + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py new file mode 100644 index 0000000000..a832e23b54 --- /dev/null +++ b/TTS/tts/models/delightful_tts.py @@ -0,0 +1,1774 @@ +import os +from dataclasses import dataclass, field +from itertools import chain +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel +from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 +from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy +from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy +from TTS.utils.audio.processor import AudioProcessor +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.losses import MultiScaleSTFTLoss +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is 
not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: + out_list = torch.jit.annotate(List[torch.Tensor], []) + for batch in input_ele: + if len(batch.shape) == 1: + one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) + else: + one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." + return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +# pylint: disable=redefined-outer-name +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path: str): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load( + file_path, + ) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window # pylint: disable=global-statement + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = 
y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + return spec + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def wav_to_energy(y, n_fft, hop_length, win_length, center=False): + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return torch.norm(spec, dim=1, keepdim=True) + + +def name_mel_basis(spec, n_fft, fmax): + n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + return n_fft_len + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis # pylint: disable=global-statement + mel_basis_key = name_mel_basis(spec, n_fft, fmax) + # pylint: disable=too-many-function-args + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[mel_basis_key], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T_y]` + + Return Shapes: + - spec : :math:`[B,C,T_spec]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + mel_basis_key = name_mel_basis(y, n_fft, fmax) + wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn( + sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) # pylint: disable=too-many-function-args + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create balancer weight for torch WeightedSampler""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = 
np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid slow computing of pitches""" + + def __init__( + self, + ap, + samples: Union[List[List], List[Dict]], + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + super().__init__( + samples=samples, + ap=ap, + verbose=verbose, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + def _compute_and_save_pitch(self, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0( + x=wav.numpy()[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + pitch_fmin=self.ap.pitch_fmin, + win_length=self.ap.win_length, + ) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % self.ap.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file, audio_name): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + self.ap = kwargs["ap"] + + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + ap=self.ap, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + if self.attn_prior_cache_path is not None: + os.makedirs(self.attn_prior_cache_path, exist_ok=True) + + def __getitem__(self, idx): + item = self.samples[idx] + + rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") + rel_wav_path = str(rel_wav_path).replace("/", "_") + + raw_text = item["text"] + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + try: + token_ids = self.get_token_ids(idx, item["text"]) + except: + print(idx, item) + # pylint: disable=raise-missing-from + raise OSError + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + attn_prior = None + if self.attn_prior_cache_path is not None: + attn_prior = 
self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "attn_prior": attn_prior, + "audio_unique_name": item["audio_unique_name"], + } + + def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): + """Load or compute and save the attention prior.""" + attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") + # pylint: disable=no-else-return + if os.path.exists(attn_prior_file): + return np.load(attn_prior_file) + else: + token_len = len(token_ids) + mel_len = wav.shape[1] // self.ap.hop_length + attn_prior = compute_attn_prior(token_len, mel_len) + np.save(attn_prior_file, attn_prior) + return attn_prior + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - attn_prior: :math:`[[T_token, T_mel]]` + """ + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_padded = None + if self.compute_f0: + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) + pitch_padded = pitch_padded.zero_() + self.pad_id + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + + for i in range(B): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + if self.compute_f0: + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_unique_names": batch["audio_unique_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, + } + + +############################## +# CONFIG DEFINITIONS +############################## + + +@dataclass +class VocoderConfig(Coqpit): + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = 
field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + use_spectral_norm_discriminator: bool = False + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + pretrained_model_path: Optional[str] = None + + +@dataclass +class DelightfulTtsAudioConfig(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 100 + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + do_trim_silence: bool = True + trim_db: int = 45 + do_rms_norm: bool = False + db_level: float = None + power: float = 1.5 + griffin_lim_iters: int = 60 + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + min_level_db: int = -100 + max_norm: float = 4.0 + + +@dataclass +class DelightfulTtsArgs(Coqpit): + num_chars: int = 100 + spec_segment_size: int = 32 + n_hidden_conformer_encoder: int = 512 + n_layers_conformer_encoder: int = 6 + n_heads_conformer_encoder: int = 8 + dropout_conformer_encoder: float = 0.1 + kernel_size_conv_mod_conformer_encoder: int = 7 + kernel_size_depthwise_conformer_encoder: int = 7 + lrelu_slope: float = 0.3 + n_hidden_conformer_decoder: int = 512 + n_layers_conformer_decoder: int = 6 + n_heads_conformer_decoder: int = 8 + dropout_conformer_decoder: float = 0.1 + kernel_size_conv_mod_conformer_decoder: int = 11 + kernel_size_depthwise_conformer_decoder: int = 11 + bottleneck_size_p_reference_encoder: int = 4 + bottleneck_size_u_reference_encoder: int = 512 + ref_enc_filters_reference_encoder = [32, 32, 64, 64, 128, 128] + ref_enc_size_reference_encoder: int = 3 + ref_enc_strides_reference_encoder = [1, 2, 1, 2, 1] + ref_enc_pad_reference_encoder = [1, 1] + ref_enc_gru_size_reference_encoder: int = 32 + ref_attention_dropout_reference_encoder: float = 0.2 + token_num_reference_encoder: int = 32 + predictor_kernel_size_reference_encoder: int = 5 + n_hidden_variance_adaptor: int = 512 + kernel_size_variance_adaptor: int = 5 + dropout_variance_adaptor: float = 0.5 + n_bins_variance_adaptor: int = 256 + emb_kernel_size_variance_adaptor: int = 3 + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 384 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + freeze_vocoder: bool = False + freeze_text_encoder: bool = False + freeze_duration_predictor: bool = False + freeze_pitch_predictor: bool = False + freeze_energy_predictor: bool = False + freeze_basis_vectors_predictor: bool = False + freeze_decoder: bool = False + length_scale: float = 1.0 + + +############################## +# MODEL DEFINITION +############################## +class DelightfulTTS(BaseTTSE2E): + """ + Paper:: + https://arxiv.org/pdf/2110.12612.pdf + + Paper Abstract:: + This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. 
+ The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: + The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems + with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves + the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and + propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training + efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and + implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and + inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) + For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. + Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test + and 4.35 in SMOS test, which indicates the effectiveness of our proposed system + + + Model training:: + text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg + spec --------^ + + Examples: + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + self.ap = ap + + self._set_model_args(config) + self.init_multispeaker(config) + self.binary_loss_weight = None + + self.args.out_channels = self.config.audio.num_mels + self.args.num_mels = self.config.audio.num_mels + self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) + + self.waveform_decoder = HifiganGenerator( + self.config.audio.num_mels, + 1, + self.config.vocoder.resblock_type_decoder, + self.config.vocoder.resblock_dilation_sizes_decoder, + self.config.vocoder.resblock_kernel_sizes_decoder, + self.config.vocoder.upsample_kernel_sizes_decoder, + self.config.vocoder.upsample_initial_channel_decoder, + self.config.vocoder.upsample_rates_decoder, + inference_padding=0, + # cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.config.init_discriminator: + self.disc = VitsDiscriminator( + use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, + periods=self.config.vocoder.periods_discriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def energy_scaler(self): + return self.acoustic_model.energy_scaler + + @property + def length_scale(self): + return self.acoustic_model.length_scale + + @length_scale.setter + def 
length_scale(self, value): + self.acoustic_model.length_scale = value + + @property + def pitch_mean(self): + return self.acoustic_model.pitch_mean + + @pitch_mean.setter + def pitch_mean(self, value): + self.acoustic_model.pitch_mean = value + + @property + def pitch_std(self): + return self.acoustic_model.pitch_std + + @pitch_std.setter + def pitch_std(self, value): + self.acoustic_model.pitch_std = value + + @property + def mel_basis(self): + return build_mel_basis( + sample_rate=self.ap.sample_rate, + fft_size=self.ap.fft_size, + num_mels=self.ap.num_mels, + mel_fmax=self.ap.mel_fmax, + mel_fmin=self.ap.mel_fmin, + ) + + def init_for_training(self) -> None: + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + self.config.steps_to_start_discriminator <= 0 + ) # pylint: disable=attribute-defined-outside-init + self.update_energy_scaler = True # pylint: disable=attribute-defined-outside-init + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + self.args.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.args.embedded_speaker_dim = self.args.speaker_embedding_channels + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + self.args.embedded_speaker_dim = self.args.d_vector_dim + + def _freeze_layers(self): + if self.args.freeze_vocoder: + for param in self.vocoder.parameters(): + param.requires_grad = False + + if self.args.freeze_text_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_duration_predictor: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_pitch_predictor: + for param in self.pitch_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_energy_predictor: + for param in self.energy_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_decoder: + for param in self.decoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + spec_lengths: torch.LongTensor, + spec: torch.FloatTensor, + waveform: torch.FloatTensor, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + attn_priors: torch.FloatTensor = None, + d_vectors: torch.FloatTensor = None, + speaker_idx: torch.LongTensor = None, + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + spec_lengths (torch.LongTensor): Spectrogram sequence lengths. Defaults to None. + spec (torch.FloatTensor): Spectrogram frames. 
Only used when the alignment network is on. Defaults to None. + waveform (torch.FloatTensor): Waveform. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + attn_priors (torch.FloatTentrasor): Attention priors for the aligner network. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - spec_lengths: :math:`[B]` + - spec: :math:`[B, T_max2, C_spec]` + - waveform: :math:`[B, 1, T_max2 * hop_length]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T_max2]` + - energy: :math:`[B, 1, T_max2]` + """ + encoder_outputs = self.acoustic_model( + tokens=x, + src_lens=x_lengths, + mel_lens=spec_lengths, + mels=spec, + pitches=pitch, + energies=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_idx, + ) + + # use mel-spec from the decoder + vocoder_input = encoder_outputs["model_outputs"] # [B, T_max2, C_mel] + + vocoder_input_slices, slice_ids = rand_segments( + x=vocoder_input.transpose(1, 2), + x_lengths=spec_lengths, + segment_size=self.args.spec_segment_size, + let_short_samples=True, + pad_short=True, + ) + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) + wav_seg = segment( + waveform, + slice_ids * self.ap.hop_length, + self.args.spec_segment_size * self.ap.hop_length, + pad_short=True, + ) + model_outputs = {**encoder_outputs} + model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] + model_outputs["model_outputs"] = vocoder_output + model_outputs["waveform_seg"] = wav_seg + model_outputs["slice_ids"] = slice_ids + return model_outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None + ): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + pitch_transform=pitch_transform, + energy_transform=energy_transform, + p_control=None, + d_control=None, + ) + vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) # [B, T_max2, C_mel] -> [B, C_mel, T_max2] + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) + model_outputs = {**encoder_outputs} + model_outputs["model_outputs"] = vocoder_output + return model_outputs + + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + ) + model_outputs = {**encoder_outputs} + return model_outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + if optimizer_idx == 0: + tokens = batch["text_input"] + token_lenghts = batch["text_lengths"] + mel = batch["mel_input"] + mel_lens = batch["mel_lengths"] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + 
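For reference, the forward pass above pairs rand_segments on the decoder spectrogram with segment on the waveform, scaling the slice indices by hop_length so both windows cover the same stretch of audio. A minimal sketch of that frame-to-sample correspondence, assuming both helpers live in TTS.tts.utils.helpers (the layers test later in this diff imports rand_segments from there) and using the default sizes shown in this diff:

import torch
from TTS.tts.utils.helpers import rand_segments, segment

hop_length = 256          # DelightfulTtsAudioConfig default
spec_segment_size = 32    # DelightfulTtsArgs default

spec = torch.randn(2, 100, 180)              # [B, C_mel, T_mel]
spec_lens = torch.tensor([180, 120])
wav = torch.randn(2, 1, 180 * hop_length)    # [B, 1, T_mel * hop_length]

# Random spectrogram windows and the frame index each window starts at.
spec_slices, slice_ids = rand_segments(
    x=spec, x_lengths=spec_lens, segment_size=spec_segment_size, let_short_samples=True, pad_short=True
)
# Matching waveform windows: start frame * hop_length samples, window of 32 frames * hop_length samples.
wav_slices = segment(wav, slice_ids * hop_length, spec_segment_size * hop_length, pad_short=True)

print(spec_slices.shape)  # torch.Size([2, 100, 32])
print(wav_slices.shape)   # torch.Size([2, 1, 8192])  ->  32 frames * 256 samples per frame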
attn_priors = batch["attn_priors"] + energy = batch["energy"] + + # generator pass + outputs = self.forward( + x=tokens, + x_lengths=token_lenghts, + spec_lengths=mel_lens, + spec=mel, + waveform=waveform, + pitch=pitch, + energy=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_ids, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + if self.train_disc: + # compute scores and features + scores_d_fake, _, scores_d_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_fake=scores_d_fake, + scores_disc_real=scores_d_real, + ) + return outputs, loss_dict + return None, None + + if optimizer_idx == 1: + mel = batch["mel_input"] + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True + ) + + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + ) + + scores_d_fake = None + feats_d_fake = None + feats_d_real = None + + if self.train_disc: + # compute discriminator scores and features + scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=True): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), + mel_target=batch["mel_input"], + mel_lens=batch["mel_lengths"], + dur_output=self.model_outputs_cache["dr_log_pred"], + dur_target=self.model_outputs_cache["dr_log_target"].detach(), + pitch_output=self.model_outputs_cache["pitch_pred"], + pitch_target=self.model_outputs_cache["pitch_target"], + energy_output=self.model_outputs_cache["energy_pred"], + energy_target=self.model_outputs_cache["energy_target"], + src_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], + p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], + p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], + u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], + u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], + aligner_logprob=self.model_outputs_cache["aligner_logprob"], + aligner_hard=self.model_outputs_cache["aligner_mas"], + aligner_soft=self.model_outputs_cache["aligner_soft"], + binary_loss_weight=self.binary_loss_weight, + feats_fake=feats_d_fake, + feats_real=feats_d_real, + scores_fake=scores_d_fake, + spec_slice=mel_slice, + spec_slice_hat=mel_slice_hat, + skip_disc=not self.train_disc, + ) + + loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() + loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() + loss_dict["avg_text_batch_occupancy"] = ( + batch["text_lengths"].float() / batch["text_lengths"].float().max() + ).mean() + loss_dict["avg_mel_batch_occupancy"] = ( + batch["mel_lengths"].float() / batch["mel_lengths"].float().max() + ).mean() + + return self.model_outputs_cache, loss_dict + raise 
ValueError(" [!] Unexpected `optimizer_idx`.") + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} + + # encoder outputs + model_outputs = outputs[1]["acoustic_model_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy( + mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio + ) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + + # vocoder outputs + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) + figures.update(vocoder_figures) + + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios[f"{name_prefix}/vocoder_audio"] = sample_voice + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use, unused-argument + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previous training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. 
+ """ + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav = None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector = None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.ids[speaker_name] + + return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} + + def plot_outputs(self, text, wav, alignment, outputs): + figures = {} + pitch_avg_pred = outputs["pitch"].cpu() + energy_avg_pred = outputs["energy"].cpu() + spec = wav_to_mel( + y=torch.from_numpy(wav[None, :]), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + )[0].transpose(0, 1) + pitch = compute_f0( + x=wav[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + ) + input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) + input_text = input_text.replace("<BLNK>", "_") + durations = outputs["durations"] + pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) # [1, 1, n_frames] + pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean + figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) + figures["spectrogram"] = plot_spectrogram(spec) + figures["pitch_from_wav"] = plot_pitch(pitch, spec) + figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) + figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) + figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) + return figures + + def synthesize( + self, + text: str, + speaker_id: str = None, + d_vector: torch.tensor = None, + pitch_transform=None, + **kwargs, + ): # pylint: disable=unused-argument + # TODO: add cloning support with ref_waveform + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + + # set speaker inputs + _speaker_id = None + if speaker_id is not None and 
(self.args.use_speaker_embedding or self.args.use_d_vector_file): + if isinstance(speaker_id, str) and self.args.use_speaker_embedding: + # get the speaker id for the speaker embedding layer + _speaker_id = self.speaker_manager.name_to_id[speaker_id] + _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) + else: + # get the average d_vector for the speaker + d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) + + if d_vector is not None and self.args.use_d_vector_file: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference( + text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, + pitch_transform=pitch_transform, + # energy_transform=energy_transform + ) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, d_vector): + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder( + x=text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, + ) + + # collect outputs + S = outputs["model_outputs"].cpu().numpy()[0].T + S = db_to_amp_numpy(x=S, gain=1, base=None) + wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + outputs = self.synthesize( + aux_inputs["text"], + config=self.config, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + # speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]] + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + + ac = self.ap + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + + # format attn priors as we now the max mel length + # TODO: fix 1 diff b/w mel_lengths and attn_priors + + if self.config.use_attn_priors: + attn_priors_np = batch["attn_priors"] + + batch["attn_priors"] = torch.zeros( + batch["mel_input"].shape[0], + batch["mel_lengths"].max(), + batch["text_lengths"].max(), + device=batch["mel_input"].device, + ) + + for i in range(batch["mel_input"].shape[0]): + batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( + attn_priors_np[i] + ) + + batch["energy"] = None + 
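The attention priors padded into the batch above are plain [T_mel, T_text] numpy arrays built from a beta-binomial prior; the helper added to TTS/tts/utils/helpers.py later in this diff does exactly this. A small standalone sketch of the construction (sizes illustrative):

import numpy as np
from scipy.stats import betabinom

def beta_binomial_prior(phoneme_count, mel_count, scaling_factor=1.0):
    # Same construction as beta_binomial_prior_distribution() in this diff:
    # one beta-binomial pmf over the token axis per mel frame.
    x = np.arange(phoneme_count)
    rows = []
    for i in range(1, mel_count + 1):
        a, b = scaling_factor * i, scaling_factor * (mel_count + 1 - i)
        rows.append(betabinom(phoneme_count, a, b).pmf(x))
    return np.stack(rows)  # [mel_count, phoneme_count]

prior = beta_binomial_prior(phoneme_count=12, mel_count=40)
print(prior.shape)           # (40, 12)
print(prior.argmax(axis=1))  # peak drifts from the first token to the last: a soft diagonal prior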
batch["energy"] = wav_to_energy( # [B, 1, T_max2] + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + center=False, + ) + batch["energy"] = self.energy_scaler(batch["energy"]) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + ap=self.ap, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait for all DDP processes to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences by length in ascending order + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + + # get pitch mean and std + self.pitch_mean = dataset.f0_dataset.mean + self.pitch_std = dataset.f0_dataset.std + return loader + + def get_criterion(self): + return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returns two optimizers in a list: the first for the discriminator and the second for the generator. + Returns: + List: optimizers. 
+ """ + optimizer_disc = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc + ) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer_gen = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer_disc, optimizer_gen] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + + def on_epoch_end(self, trainer): # pylint: disable=unused-argument + # stop updating mean and var + # TODO: do the same for F0 + self.energy_scaler.eval() + + @staticmethod + def init_from_config( + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + ): # pylint: disable=unused-argument + """Initiate model from config + + Args: + config (ForwardTTSE2eConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) + ap = AudioProcessor.init_from_config(config=config) + return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) + + def load_checkpoint(self, config, checkpoint_path, eval=False): + """Load model from a checkpoint created by the 👟""" + # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_names + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... + return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict(config, checkpoint_path) # pylint: disable=too-many-function-args + save_state["pitch_mean"] = self.pitch_mean + save_state["pitch_std"] = self.pitch_std + torch.save(save_state, checkpoint_path) + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. 
+ """ + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + trainer.total_steps_done >= self.config.steps_to_start_discriminator + ) + + +class DelightfulTTSLoss(nn.Module): + def __init__(self, config): + super().__init__() + + self.mse_loss = nn.MSELoss() + self.mae_loss = nn.L1Loss() + self.forward_sum_loss = ForwardSumLoss() + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) + + self.mel_loss_alpha = config.mel_loss_alpha + self.aligner_loss_alpha = config.aligner_loss_alpha + self.pitch_loss_alpha = config.pitch_loss_alpha + self.energy_loss_alpha = config.energy_loss_alpha + self.u_prosody_loss_alpha = config.u_prosody_loss_alpha + self.p_prosody_loss_alpha = config.p_prosody_loss_alpha + self.dur_loss_alpha = config.dur_loss_alpha + self.char_dur_loss_alpha = config.char_dur_loss_alpha + self.binary_alignment_loss_alpha = config.binary_align_loss_alpha + + self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + mel_output, + mel_target, + mel_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + src_lens, + waveform, + waveform_hat, + p_prosody_ref, + p_prosody_pred, + u_prosody_ref, + u_prosody_pred, + aligner_logprob, + aligner_hard, + aligner_soft, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + skip_disc=False, + ): + """ + Shapes: + - mel_output: :math:`(B, C_mel, T_mel)` + - mel_target: :math:`(B, C_mel, T_mel)` + - mel_lens: :math:`(B)` + - dur_output: :math:`(B, T_src)` + - dur_target: :math:`(B, T_src)` + - pitch_output: :math:`(B, 1, T_src)` + - pitch_target: :math:`(B, 1, T_src)` + - energy_output: :math:`(B, 1, T_src)` + - energy_target: :math:`(B, 1, T_src)` + - src_lens: :math:`(B)` + - waveform: :math:`(B, 1, T_wav)` + - waveform_hat: :math:`(B, 1, T_wav)` + - p_prosody_ref: :math:`(B, T_src, 4)` + - p_prosody_pred: :math:`(B, T_src, 4)` + - u_prosody_ref: :math:`(B, 1, 256) + - u_prosody_pred: :math:`(B, 1, 256) + - aligner_logprob: :math:`(B, 1, T_mel, T_src)` + - aligner_hard: :math:`(B, T_mel, T_src)` + - aligner_soft: :math:`(B, T_mel, T_src)` + - spec_slice: :math:`(B, C_mel, T_mel)` + - spec_slice_hat: :math:`(B, C_mel, T_mel)` + """ + loss_dict = {} + src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src) + mel_mask = sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel) + + dur_target.requires_grad = False + mel_target.requires_grad = False + 
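As a quick illustration of _binary_alignment_loss defined above: it is the negative mean log of the soft alignment probabilities at the positions selected by the hard (0/1) alignment, so it is ln(K) for a uniform soft alignment over K tokens and shrinks toward zero as the soft mass concentrates on the hard path. A toy check, independent of the model:

import torch

def binary_alignment_loss(alignment_hard, alignment_soft):
    # Mirrors DelightfulTTSLoss._binary_alignment_loss in this diff.
    log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
    return -log_sum / alignment_hard.sum()

# 1 sample, 4 mel frames, 3 tokens: uniform soft alignment vs. a monotonic hard path.
soft = torch.full((1, 4, 3), 1.0 / 3.0)
hard = torch.zeros(1, 4, 3)
hard[0, torch.arange(4), torch.tensor([0, 0, 1, 2])] = 1.0

print(binary_alignment_loss(hard, soft))  # ln(3) ~= 1.0986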
pitch_target.requires_grad = False + + masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) + mel_targets = mel_target.masked_select(mel_mask[:, None]) + mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) + + p_prosody_ref = p_prosody_ref.detach() + p_prosody_loss = 0.5 * self.mae_loss( + p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), + p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), + ) + + u_prosody_ref = u_prosody_ref.detach() + u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) + + duration_loss = self.mse_loss(dur_output, dur_target) + + pitch_output = pitch_output.masked_select(src_mask[:, None]) + pitch_target = pitch_target.masked_select(src_mask[:, None]) + pitch_loss = self.mse_loss(pitch_output, pitch_target) + + energy_output = energy_output.masked_select(src_mask[:, None]) + energy_target = energy_target.masked_select(src_mask[:, None]) + energy_loss = self.mse_loss(energy_output, energy_target) + + forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) + + total_loss = ( + (mel_loss * self.mel_loss_alpha) + + (duration_loss * self.dur_loss_alpha) + + (u_prosody_loss * self.u_prosody_loss_alpha) + + (p_prosody_loss * self.p_prosody_loss_alpha) + + (pitch_loss * self.pitch_loss_alpha) + + (energy_loss * self.energy_loss_alpha) + + (forward_sum_loss * self.aligner_loss_alpha) + ) + + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + if binary_loss_weight: + loss_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss + loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss + loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss + loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss + loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss + loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + loss_dict["loss"] = total_loss + + # vocoder losses + if not skip_disc: + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen + + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg + return loss_dict diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index af7ea583a0..d998825628 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -513,9 
+513,13 @@ def synthesize(self, text, config, speaker_id="random", voice_dirs=None, **kwarg as latents used at inference. """ + + speaker_id = "random" if speaker_id is None else speaker_id + if voice_dirs is not None: voice_dirs = [voice_dirs] voice_samples, conditioning_latents = load_voice(speaker_id, voice_dirs) + else: voice_samples, conditioning_latents = load_voice(speaker_id) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index c6d1ec2c06..7b37201f84 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -1,5 +1,6 @@ import numpy as np import torch +from scipy.stats import betabinom from torch.nn import functional as F try: @@ -233,3 +234,25 @@ def maximum_path_numpy(value, mask, max_neg_val=None): path = path * mask.astype(np.float32) path = torch.from_numpy(path).to(device=device, dtype=dtype) return path + + +def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): + P, M = phoneme_count, mel_count + x = np.arange(0, P) + mel_text_probs = [] + for i in range(1, M + 1): + a, b = scaling_factor * i, scaling_factor * (M + 1 - i) + rv = betabinom(P, a, b) + mel_i_prob = rv.pmf(x) + mel_text_probs.append(mel_i_prob) + return np.array(mel_text_probs) + + +def compute_attn_prior(x_len, y_len, scaling_factor=1.0): + """Compute attention priors for the alignment network.""" + attn_prior = beta_binomial_prior_distribution( + x_len, + y_len, + scaling_factor, + ) + return attn_prior # [y_len, x_len] diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 760738467e..bc0e231df0 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -361,12 +361,12 @@ def tts( if not reference_wav: # not voice conversion for sen in sens: if hasattr(self.tts_model, "synthesize"): - sp_name = "random" if speaker_name is None else speaker_name outputs = self.tts_model.synthesize( text=sen, config=self.tts_config, - speaker_id=sp_name, + speaker_id=speaker_name, voice_dirs=self.voice_dir, + d_vector=speaker_embedding, **kwargs, ) else: diff --git a/recipes/ljspeech/delightful_tts/train_delightful_tts.py b/recipes/ljspeech/delightful_tts/train_delightful_tts.py new file mode 100644 index 0000000000..81e40c84ae --- /dev/null +++ b/recipes/ljspeech/delightful_tts/train_delightful_tts.py @@ -0,0 +1,84 @@ +import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.delightful_tts import DelightfulTTS, DelightfulTtsArgs, VocoderConfig +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio.processor import AudioProcessor + +data_path = "" +output_path = os.path.dirname(os.path.abspath(__file__)) + +dataset_config = BaseDatasetConfig( + dataset_name="ljspeech", formatter="ljspeech", meta_file_train="metadata.csv", path=data_path +) + +audio_config = DelightfulTtsAudioConfig() +model_args = DelightfulTtsArgs() + +vocoder_config = VocoderConfig() + +delightful_tts_config = DelightfulTTSConfig( + run_name="delightful_tts_ljspeech", + run_description="Train like in delightful tts paper.", + model_args=model_args, + audio=audio_config, + vocoder=vocoder_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=10, + num_eval_loader_workers=10, + precompute_num_workers=10, + batch_group_size=2, + compute_input_seq_cache=True, + compute_f0=True, + f0_cache_path=os.path.join(output_path, 
"f0_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + print_step=50, + print_eval=False, + mixed_precision=True, + output_path=output_path, + datasets=[dataset_config], + start_by_longest=False, + eval_split_size=0.1, + binary_align_loss_alpha=0.0, + use_attn_priors=False, + lr_gen=4e-1, + lr=4e-1, + lr_disc=4e-1, + max_text_len=130, +) + +tokenizer, config = TTSTokenizer.init_from_config(delightful_tts_config) + +ap = AudioProcessor.init_from_config(config) + + +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + +model = DelightfulTTS(ap=ap, config=config, tokenizer=tokenizer, speaker_manager=None) + +trainer = Trainer( + TrainerArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, +) + +trainer.fit() diff --git a/recipes/vctk/delightful_tts/train_delightful_tts.py b/recipes/vctk/delightful_tts/train_delightful_tts.py new file mode 100644 index 0000000000..1980303543 --- /dev/null +++ b/recipes/vctk/delightful_tts/train_delightful_tts.py @@ -0,0 +1,86 @@ +import os + +from clearml import Task +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTTSE2e, VocoderConfig +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio.processor import AudioProcessor + +task = Task.init(project_name="delightful-tts", task_name="vctk") +data_path = "/raid/datasets/vctk_v092_48khz_removed_silence_silero_vad" +output_path = os.path.dirname(os.path.abspath(__file__)) + + +dataset_config = BaseDatasetConfig(dataset_name="vctk", meta_file_train="", path=data_path, language="en-us") + +audio_config = DelightfulTtsAudioConfig() + +model_args = DelightfulTtsArgs() + +vocoder_config = VocoderConfig() + +something_tts_config = DelightfulTTSConfig( + run_name="delightful_tts_e2e_ljspeech", + run_description="Train like in delightful tts paper.", + model_args=model_args, + audio=audio_config, + vocoder=vocoder_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=10, + num_eval_loader_workers=10, + precompute_num_workers=40, + compute_input_seq_cache=True, + compute_f0=True, + f0_cache_path=os.path.join(output_path, "f0_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + print_step=50, + print_eval=False, + mixed_precision=True, + output_path=output_path, + datasets=[dataset_config], + start_by_longest=True, + binary_align_loss_alpha=0.0, + use_attn_priors=False, + max_text_len=60, + steps_to_start_discriminator=10000, +) + +tokenizer, config = TTSTokenizer.init_from_config(something_tts_config) + +ap = AudioProcessor.init_from_config(config) + + +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + + +speaker_manager = SpeakerManager() 
+speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name") +config.model_args.num_speakers = speaker_manager.num_speakers + + +model = DelightfulTTSE2e( + ap=ap, config=config, tokenizer=tokenizer, speaker_manager=speaker_manager, emotion_manager=None +) + +trainer = Trainer( + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples +) + +trainer.fit() diff --git a/tests/tts_tests2/__init__.py b/tests/tts_tests2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests2/test_align_tts_train.py similarity index 100% rename from tests/tts_tests/test_align_tts_train.py rename to tests/tts_tests2/test_align_tts_train.py diff --git a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py new file mode 100644 index 0000000000..e6d04747ff --- /dev/null +++ b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py @@ -0,0 +1,98 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +audio_config = DelightfulTtsAudioConfig() +model_args = DelightfulTtsArgs( + use_speaker_embedding=False, d_vector_dim=256, use_d_vector_file=True, speaker_embedding_channels=256 +) + +vocoder_config = VocoderConfig() + +config = DelightfulTTSConfig( + model_args=model_args, + audio=audio_config, + vocoder=vocoder_config, + batch_size=2, + eval_batch_size=8, + compute_f0=True, + run_eval=True, + test_delay_epochs=-1, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache_delightful/", ## delightful f0 cache is incompatible with other models + epochs=1, + print_step=1, + print_eval=True, + binary_align_loss_alpha=0.0, + use_attn_priors=False, + test_sentences=["Be a voice, not an echo."], + output_path=output_path, + use_speaker_embedding=False, + use_d_vector_file=True, + d_vector_file="tests/data/ljspeech/speakers.json", + d_vector_dim=256, + speaker_embedding_channels=256, +) + +# active multispeaker d-vec mode +config.model_args.use_speaker_embedding = False +config.model_args.use_d_vector_file = True +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.d_vector_dim = 256 + + +config.save_json(config_path) + +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) + +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, 
"config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +speaker_id = "ljspeech-1" +continue_speakers_path = config.d_vector_file + +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --speakers_file_path {continue_speakers_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) +shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/") diff --git a/tests/tts_tests2/test_delightful_tts_emb_spk.py b/tests/tts_tests2/test_delightful_tts_emb_spk.py new file mode 100644 index 0000000000..d72536d859 --- /dev/null +++ b/tests/tts_tests2/test_delightful_tts_emb_spk.py @@ -0,0 +1,92 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +audio_config = DelightfulTtsAudioConfig() +model_args = DelightfulTtsArgs(use_speaker_embedding=False) + +vocoder_config = VocoderConfig() + +config = DelightfulTTSConfig( + model_args=model_args, + audio=audio_config, + vocoder=vocoder_config, + batch_size=2, + eval_batch_size=8, + compute_f0=True, + run_eval=True, + test_delay_epochs=-1, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache_delightful/", ## delightful f0 cache is incompatible with other models + epochs=1, + print_step=1, + print_eval=True, + binary_align_loss_alpha=0.0, + use_attn_priors=False, + test_sentences=["Be a voice, not an echo."], + output_path=output_path, + num_speakers=4, + use_speaker_embedding=True, +) + +# active multispeaker d-vec mode +config.model_args.use_speaker_embedding = True +config.model_args.use_d_vector_file = False +config.model_args.d_vector_file = None +config.model_args.d_vector_dim = 256 + + +config.save_json(config_path) + +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.dataset_name ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) + +run_cli(command_train) + +# Find latest folder +continue_path = 
max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech" +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) +shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/") diff --git a/tests/tts_tests2/test_delightful_tts_layers.py b/tests/tts_tests2/test_delightful_tts_layers.py new file mode 100644 index 0000000000..073bb1eb5a --- /dev/null +++ b/tests/tts_tests2/test_delightful_tts_layers.py @@ -0,0 +1,91 @@ +import torch + +from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig +from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig +from TTS.tts.utils.helpers import rand_segments +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.vocoder.models.hifigan_generator import HifiganGenerator + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +args = DelightfulTtsArgs() +v_args = VocoderConfig() + + +config = DelightfulTTSConfig( + model_args=args, + # compute_f0=True, + # f0_cache_path=os.path.join(output_path, "f0_cache"), + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + # phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), +) + +tokenizer, config = TTSTokenizer.init_from_config(config) + + +def test_acoustic_model(): + dummy_tokens = torch.rand((1, 41)).long().to(device) + dummy_text_lens = torch.tensor([41]).to(device) + dummy_spec = torch.rand((1, 100, 207)).to(device) + dummy_spec_lens = torch.tensor([207]).to(device) + dummy_pitch = torch.rand((1, 1, 207)).long().to(device) + dummy_energy = torch.rand((1, 1, 207)).long().to(device) + + args.out_channels = 100 + args.num_mels = 100 + + acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device) + + output = acoustic_model( + tokens=dummy_tokens, + src_lens=dummy_text_lens, + mel_lens=dummy_spec_lens, + mels=dummy_spec, + pitches=dummy_pitch, + energies=dummy_energy, + attn_priors=None, + d_vectors=None, + speaker_idx=None, + ) + assert list(output["model_outputs"].shape) == [1, 207, 100] + output["model_outputs"].sum().backward() + + +def test_hifi_decoder(): + dummy_input = torch.rand((1, 207, 100)).to(device) + dummy_text_lens = torch.tensor([41]).to(device) + dummy_spec = torch.rand((1, 100, 207)).to(device) + dummy_spec_lens = torch.tensor([207]).to(device) + dummy_pitch = torch.rand((1, 1, 207)).long().to(device) + dummy_energy = torch.rand((1, 1, 207)).long().to(device) + + waveform_decoder = 
HifiganGenerator( + 100, + 1, + v_args.resblock_type_decoder, + v_args.resblock_dilation_sizes_decoder, + v_args.resblock_kernel_sizes_decoder, + v_args.upsample_kernel_sizes_decoder, + v_args.upsample_initial_channel_decoder, + v_args.upsample_rates_decoder, + inference_padding=0, + cond_channels=0, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ).to(device) + + vocoder_input_slices, slice_ids = rand_segments( # pylint: disable=unused-variable + x=dummy_input.transpose(1, 2), + x_lengths=dummy_spec_lens, + segment_size=32, + let_short_samples=True, + pad_short=True, + ) + + outputs = waveform_decoder(x=vocoder_input_slices.detach()) + assert list(outputs.shape) == [1, 1, 8192] + outputs.sum().backward() diff --git a/tests/tts_tests2/test_delightful_tts_train.py b/tests/tts_tests2/test_delightful_tts_train.py new file mode 100644 index 0000000000..cef6574546 --- /dev/null +++ b/tests/tts_tests2/test_delightful_tts_train.py @@ -0,0 +1,97 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +audio_config = DelightfulTtsAudioConfig() +model_args = DelightfulTtsArgs() + +vocoder_config = VocoderConfig() + + +config = DelightfulTTSConfig( + audio=audio_config, + batch_size=2, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache_delightful/", ## delightful f0 cache is incompatible with other models + run_eval=True, + test_delay_epochs=-1, + binary_align_loss_alpha=0.0, + epochs=1, + print_step=1, + use_attn_priors=False, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + use_speaker_embedding=False, +) +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{'cpu'}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs -1" +) + +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) 
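+# The restored config should still define its character set, point back into this run's output folder, and keep the test_delay_epochs value set above; the asserts below verify this before running inference and continuing training.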
+assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == -1 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) +shutil.rmtree("tests/data/ljspeech/f0_cache_delightful/") diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py similarity index 100% rename from tests/tts_tests/test_fast_pitch_speaker_emb_train.py rename to tests/tts_tests2/test_fast_pitch_speaker_emb_train.py diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests2/test_fast_pitch_train.py similarity index 100% rename from tests/tts_tests/test_fast_pitch_train.py rename to tests/tts_tests2/test_fast_pitch_train.py diff --git a/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py similarity index 100% rename from tests/tts_tests/test_fastspeech_2_speaker_emb_train.py rename to tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py diff --git a/tests/tts_tests/test_fastspeech_2_train.py b/tests/tts_tests2/test_fastspeech_2_train.py similarity index 100% rename from tests/tts_tests/test_fastspeech_2_train.py rename to tests/tts_tests2/test_fastspeech_2_train.py diff --git a/tests/tts_tests/test_feed_forward_layers.py b/tests/tts_tests2/test_feed_forward_layers.py similarity index 100% rename from tests/tts_tests/test_feed_forward_layers.py rename to tests/tts_tests2/test_feed_forward_layers.py diff --git a/tests/tts_tests/test_forward_tts.py b/tests/tts_tests2/test_forward_tts.py similarity index 100% rename from tests/tts_tests/test_forward_tts.py rename to tests/tts_tests2/test_forward_tts.py diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests2/test_glow_tts.py similarity index 100% rename from tests/tts_tests/test_glow_tts.py rename to tests/tts_tests2/test_glow_tts.py diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests2/test_glow_tts_d-vectors_train.py similarity index 100% rename from tests/tts_tests/test_glow_tts_d-vectors_train.py rename to tests/tts_tests2/test_glow_tts_d-vectors_train.py diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py similarity index 100% rename from tests/tts_tests/test_glow_tts_speaker_emb_train.py rename to tests/tts_tests2/test_glow_tts_speaker_emb_train.py diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests2/test_glow_tts_train.py similarity index 100% rename from tests/tts_tests/test_glow_tts_train.py rename to tests/tts_tests2/test_glow_tts_train.py