Fix XTTS GPT padding and inference issues #3216

Merged · 3 commits · Nov 15, 2023

Changes from 1 commit
Remove unused code
Edresson committed Nov 14, 2023
commit f5fd31f8e83e37fc991c7eb9ae874c5a90eabb2e

31 changes: 0 additions & 31 deletions TTS/tts/models/xtts.py
@@ -7,7 +7,6 @@
 import torchaudio
 from coqpit import Coqpit
 
-from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
@@ -308,26 +307,6 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

-    @torch.inference_mode()
-    def get_diffusion_cond_latents(self, audio, sr):
-        from math import ceil
-
-        diffusion_conds = []
-        CHUNK_SIZE = 102400
-        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
-        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
-            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
-            current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
-            cond_mel = wav_to_univnet_mel(
-                current_sample.to(self.device),
-                do_normalization=False,
-                device=self.device,
-            )
-            diffusion_conds.append(cond_mel)
-        diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
-        return diffusion_latent
-
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
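
The method deleted in the hunk above fed XTTS's old diffusion decoder: it resampled the reference audio to 24 kHz, split it into fixed 102400-sample windows, and stacked one conditioning mel per window. With the diffusion path gone (the HifiDecoder import kept above handles vocoding instead), both the helper and its `wav_to_univnet_mel` import became dead code, which is what this commit removes. For readers who want the chunking pattern in isolation, here is a minimal standalone sketch; `pad_or_truncate` is a hypothetical stand-in for the Tortoise utility of the same name, and the shapes are purely illustrative:

```python
import math

import torch
import torch.nn.functional as F


def pad_or_truncate(t: torch.Tensor, length: int) -> torch.Tensor:
    # Hypothetical stand-in for the Tortoise audio utility: zero-pad or cut
    # the last dimension so every window is exactly `length` samples.
    if t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))
    return t[..., :length]


def fixed_size_chunks(audio: torch.Tensor, chunk_size: int = 102400):
    # Mirrors the loop in the removed get_diffusion_cond_latents: walk the
    # waveform in equal windows, padding the final partial one.
    for i in range(math.ceil(audio.shape[1] / chunk_size)):
        yield pad_or_truncate(audio[:, i * chunk_size : (i + 1) * chunk_size], chunk_size)


audio_24k = torch.randn(1, 250_000)  # ~10.4 s of fake 24 kHz audio
chunks = list(fixed_size_chunks(audio_24k))
print(len(chunks), chunks[-1].shape)  # 3 windows, each (1, 102400)
```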
@@ -575,16 +554,6 @@ def inference(
                 return_attentions=False,
                 return_latent=True,
             )
-            silence_token = 83
-            ctokens = 0
-            for k in range(gpt_codes.shape[-1]):
-                if gpt_codes[0, k] == silence_token:
-                    ctokens += 1
-                else:
-                    ctokens = 0
-                if ctokens > 8:
-                    gpt_latents = gpt_latents[:, :k]
-                    break

             if length_scale != 1.0:
                 gpt_latents = F.interpolate(
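
The ten deleted lines in this hunk implemented an early-exit heuristic: scan the generated GPT codes and, once more than eight consecutive silence tokens (id 83) appear, truncate the latents at that point. A self-contained sketch of that heuristic, with the token id and threshold taken from the removed lines and the latent shape assumed for illustration:

```python
import torch


def trim_at_silence_run(latents: torch.Tensor, codes: torch.Tensor,
                        silence_token: int = 83, max_run: int = 8) -> torch.Tensor:
    # Scan the generated codes left to right; once more than `max_run`
    # consecutive silence tokens appear, cut the latents at that position.
    run = 0
    for k in range(codes.shape[-1]):
        if codes[0, k] == silence_token:
            run += 1
        else:
            run = 0
        if run > max_run:
            return latents[:, :k]
    return latents


codes = torch.tensor([[5, 7, 83, 83, 83, 83, 83, 83, 83, 83, 83, 12]])
latents = torch.randn(1, codes.shape[-1], 1024)  # assumed latent width
trimmed = trim_at_silence_run(latents, codes)
print(trimmed.shape)  # cut where the 9th consecutive silence code appears
```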
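
The surviving context shows the speed control that remains after the removal: when `length_scale != 1.0`, the latents are time-stretched with `F.interpolate` before decoding. The call is cut off at the edge of the diff, so the continuation below is only an assumption-laden sketch (the `"linear"` mode and the transpose bookkeeping are guesses consistent with a `(batch, time, channels)` layout, not taken from this page):

```python
import torch
import torch.nn.functional as F

gpt_latents = torch.randn(1, 250, 1024)  # assumed (batch, time, channels) layout
length_scale = 1.2                       # values > 1.0 slow the output down

# Interpolate over the time axis, which F.interpolate treats as the last
# dimension, hence the transposes on the way in and out.
gpt_latents = F.interpolate(
    gpt_latents.transpose(1, 2),
    scale_factor=length_scale,
    mode="linear",
).transpose(1, 2)
print(gpt_latents.shape)  # torch.Size([1, 300, 1024])
```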