Skip to content

Commit

Permalink
Merge pull request #3208 from coqui-ai/fix_max_mel_len
Browse files (browse the repository at this point in the history)
fix max generation length for XTTS
  • Loading branch information
erogol authored Nov 13, 2023
2 parents f32a465 + b85536b commit ac3df40
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
@@ -128,6 +128,7 @@ def __init__(
self.heads = heads
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
+ self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
self.max_prompt_tokens = max_prompt_tokens
@@ -598,7 +599,7 @@ def generate(
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
- max_length=self.max_mel_tokens,
+ max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
@@ -611,7 +612,7 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
- max_length=self.max_mel_tokens,
+ max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
do_stream=True,
**hf_generate_kwargs,
)

0 comments on commit ac3df40

Please sign in to comment.