diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py
index 5d8b2ae66b..8cb90ad0f8 100644
--- a/TTS/tts/layers/xtts/trainer/dataset.py
+++ b/TTS/tts/layers/xtts/trainer/dataset.py
@@ -2,13 +2,10 @@
 import random
 import sys
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
-import torchaudio
-from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
+from TTS.tts.models.xtts import load_audio
 
 torch.set_num_threads(1)
 
@@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
     return rel_clip, rel_clip.shape[-1], cond_idxs
 
 
-def load_audio(audiopath, sampling_rate):
-    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
-    if audiopath[-4:] == ".mp3":
-        # it uses torchaudio with sox backend to load mp3
-        audio, lsr = torchaudio_sox_load(audiopath)
-    else:
-        # it uses torchaudio soundfile backend to load all the others data type
-        audio, lsr = torchaudio_soundfile_load(audiopath)
-
-    # stereo to mono if needed
-    if audio.size(0) != 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    if lsr != sampling_rate:
-        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
-
-    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
-    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
-    if torch.any(audio > 10) or not torch.any(audio < 0):
-        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
-        # clip audio invalid values
-        audio.clip_(-1, 1)
-    return audio
-
-
 class XTTSDataset(torch.utils.data.Dataset):
     def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index ce968053b4..fdaeb7deb8 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -67,6 +67,31 @@ def wav_to_mel_cloning(
     return mel
 
 
+def load_audio(audiopath, sampling_rate):
+    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
+    if audiopath[-4:] == ".mp3":
+        # it uses torchaudio with the sox backend to load mp3
+        audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
+    else:
+        # it uses the torchaudio soundfile backend to load all other data types
+        audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)
+
+    # stereo to mono if needed
+    if audio.size(0) != 1:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+
+    if lsr != sampling_rate:
+        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
+
+    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
+    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
+    if torch.any(audio > 10) or not torch.any(audio < 0):
+        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
+        # clip audio invalid values
+        audio.clip_(-1, 1)
+    return audio
+
+
 def pad_or_truncate(t, length):
     """
     Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
@@ -404,6 +429,7 @@ def get_conditioning_latents(
         max_ref_length=10,
         librosa_trim_db=None,
         sound_norm_refs=False,
+        load_sr=24000,
     ):
         # deal with multiples references
         if not isinstance(audio_path, list):
@@ -415,8 +441,9 @@ def get_conditioning_latents(
         audios = []
         speaker_embedding = None
         for file_path in audio_paths:
-            audio, sr = torchaudio.load(file_path)
-            audio = audio[:, : sr * max_ref_length].to(self.device)
+            # load the audio at 24 kHz to avoid issues with references at multiple sample rates
+            audio = load_audio(file_path, load_sr)
+            audio = audio[:, : load_sr * max_ref_length].to(self.device)
             if audio.shape[0] > 1:
                 audio = audio.mean(0, keepdim=True)
             if sound_norm_refs:
@@ -424,13 +451,14 @@ def get_conditioning_latents(
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
 
-            speaker_embedding = self.get_speaker_embedding(audio, sr)
+            speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)
+            audios.append(audio)
 
         # use a merge of all references for gpt cond latents
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
 
         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
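For reference, a minimal usage sketch of the relocated helper and the new load_sr parameter. The load_audio() signature and the get_conditioning_latents() arguments are taken from the diff above; the config/checkpoint loading calls, file paths, and the unpacked return values are illustrative assumptions about the surrounding XTTS API, not part of this change.

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts, load_audio

# load_audio() is now importable from the model module and shared with the
# trainer dataset: it decodes mp3 via the sox backend (soundfile otherwise),
# downmixes stereo to mono, and resamples to the requested rate.
wav = load_audio("reference_a.wav", sampling_rate=24000)  # [1, T] float tensor

# Hypothetical model setup; adapt paths and loading calls to your checkpoint layout.
config = XttsConfig()
config.load_json("/path/to/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/checkpoint/", eval=True)

# With load_sr, every reference clip is decoded at the same rate (24 kHz by
# default), so clips with different native sample rates can be concatenated
# safely into the merged GPT conditioning latents.
gpt_cond_latents, speaker_embedding = model.get_conditioning_latents(
    audio_path=["reference_a.wav", "reference_b.mp3"],
    load_sr=24000,
)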