
Commit

Load reference in 24 kHz to avoid issues with multiple sr references
Edresson authored and erogol committed Nov 6, 2023
1 parent 00294ff commit 72b2bac
Showing 2 changed files with 33 additions and 33 deletions.
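
In short, the commit moves load_audio from the XTTS trainer dataset into TTS/tts/models/xtts.py and has get_conditioning_latents load every reference clip through it at a single rate (load_sr, 24 kHz by default), rather than keeping each file's native rate from torchaudio.load. A minimal sketch of the idea, assuming only torch and torchaudio; the helper name load_at_fixed_sr and the file names are hypothetical stand-ins, not part of the commit:

import torch
import torchaudio

def load_at_fixed_sr(path: str, target_sr: int = 24000) -> torch.Tensor:
    # Load at the file's native rate, then force a common rate so that
    # several references can be concatenated and compared safely.
    audio, native_sr = torchaudio.load(path)
    if audio.size(0) != 1:  # stereo -> mono
        audio = audio.mean(dim=0, keepdim=True)
    if native_sr != target_sr:
        audio = torchaudio.functional.resample(audio, native_sr, target_sr)
    return audio

# References recorded at different sample rates now end up at the same rate,
# so concatenating them for the GPT conditioning latents is well defined.
refs = [load_at_fixed_sr(p) for p in ["ref_a.wav", "ref_b.wav"]]
full_audio = torch.cat(refs, dim=-1)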
30 changes: 1 addition & 29 deletions TTS/tts/layers/xtts/trainer/dataset.py
@@ -2,13 +2,10 @@
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

@@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
return rel_clip, rel_clip.shape[-1], cond_idxs


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio_sox_load(audiopath)
else:
# it uses the torchaudio soundfile backend to load all other data types
audio, lsr = torchaudio_soundfile_load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
36 changes: 32 additions & 4 deletions TTS/tts/models/xtts.py
@@ -67,6 +67,31 @@ def wav_to_mel_cloning(
return mel


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
@@ -404,6 +429,7 @@ def get_conditioning_latents(
max_ref_length=10,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
):
# deal with multiple references
if not isinstance(audio_path, list):
@@ -415,22 +441,24 @@
audios = []
speaker_embedding = None
for file_path in audio_paths:
audio, sr = torchaudio.load(file_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
# load the audio at 24 kHz to avoid issues with multiple sr references
audio = load_audio(file_path, load_sr)
audio = audio[:, : load_sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, sr)
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# use a merge of all references for gpt cond latents
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len) # [1, 1024, T]
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
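
With the new load_sr argument, reference clips recorded at different sample rates can be mixed freely: each one is loaded and resampled to 24 kHz before the speaker embedding and GPT conditioning latents are computed. A hedged usage sketch; model is assumed to be an initialized Xtts instance, the paths are hypothetical placeholders, and the method is assumed to return the GPT conditioning latents and speaker embedding it builds above:

# Each reference is loaded at load_sr, so their native sample rates no longer matter.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["ref_48khz.wav", "ref_16khz.wav", "ref_44khz.mp3"],
    load_sr=24000,        # single rate used for all references
    max_ref_length=10,    # seconds kept from each reference
    sound_norm_refs=False,
)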
