Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/speecht5 bug #28481

Merged
merged 9 commits into from
Jan 16, 2024
21 changes: 7 additions & 14 deletions src/transformers/models/speecht5/modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,24 +653,19 @@ def __init__(self, config):
super().__init__()
self.config = config

self.layers = nn.ModuleList(
[
nn.Linear(
config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units,
config.speech_decoder_prenet_units,
)
for i in range(config.speech_decoder_prenet_layers)
]
)
self.layers = nn.ModuleList([
nn.Linear(
config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units,
config.speech_decoder_prenet_units
) for i in range(config.speech_decoder_prenet_layers)
])

self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)

self.encode_positions = SpeechT5ScaledPositionalEncoding(
config.positional_dropout,
config.hidden_size,
config.max_speech_positions,
)

self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)

def _consistent_dropout(self, inputs_embeds, p):
Expand All @@ -695,9 +690,7 @@ def forward(

if speaker_embeddings is not None:
speaker_embeddings = nn.functional.normalize(speaker_embeddings)
speaker_embeddings = speaker_embeddings.unsqueeze(1)
speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
Comment on lines -698 to -700
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I see it, there are two possible situations:

  1. As many speaker embeddings as sample (i.e batch size = number of speaker embeddings)
  2. One-to-many speaker embeddings (i.e one speaker embeddings for every sample of the batch)

Your proposed solution addresses situation 1., and the previous solution addressed situation 2.

To be complete, we should have a code that addresses both situations ! (and that throws an Error in other cases).

speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))

Expand Down
137 changes: 97 additions & 40 deletions tests/models/speecht5/test_modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,107 +1029,164 @@ def _mock_init_weights(self, module):
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
@cached_property
def default_model(self):
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)

@cached_property
def default_processor(self):
return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

@cached_property
def default_vocoder(self):
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)

def test_generation(self):
model = self.default_model
model.to(torch_device)
processor = self.default_processor

set_seed(555) # make deterministic

speaker_embeddings = torch.zeros((1, 512)).to(torch_device)

input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)

speaker_embeddings = torch.zeros((1, 512), device=torch_device)

# Generate speech and validate output dimensions
set_seed(555) # Ensure deterministic behavior
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))

set_seed(555) # make deterministic
num_mel_bins = model.config.num_mel_bins
self.assertEqual(
generated_speech.shape[1],
num_mel_bins,
"Generated speech output has an unexpected number of mel bins."
)

# test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
# Validate generation with additional kwargs using model.generate;
# same method than generate_speech
set_seed(555) # Reset seed for consistent results
generated_speech_with_generate = model.generate(
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
)
self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
self.assertEqual(
generated_speech_with_generate.shape,
generated_speech.shape,
"Shape mismatch between generate_speech and generate methods."
)

def test_batch_generation(self):
model = self.default_model
model.to(torch_device)
processor = self.default_processor
vocoder = self.default_vocoder
set_seed(555) # make deterministic

input_text = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"nor is mister quilter's manner less interesting than his matter",
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
]
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
inputs = processor(
text=input_text, padding="max_length", max_length=128, return_tensors="pt"
).to(torch_device)
speaker_embeddings = torch.zeros((len(input_text), 512), device=torch_device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be nice to test it with different speaker embeddings (e.g random + seed) !

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


speaker_embeddings = torch.zeros((1, 512), device=torch_device)
# Generate spectrograms
set_seed(555) # Ensure deterministic behavior
spectrograms, spectrogram_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
return_output_lengths=True,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following my previous commentary, we should also test the one-to-many situation !

self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))

# Validate generated spectrogram dimensions
expected_batch_size = len(input_text)
num_mel_bins = model.config.num_mel_bins
actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
self.assertEqual(
actual_batch_size,
expected_batch_size,
"Batch size of generated spectrograms is incorrect."
)
self.assertEqual(
actual_num_mel_bins,
num_mel_bins,
"Number of mel bins in batch generated spectrograms is incorrect."
)

# Generate waveforms using the vocoder
waveforms = vocoder(spectrograms)
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

# Check waveform results are the same with or without using vocder
set_seed(555)
# Validate generation with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=True,
)
self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)

# Check waveform results are the same with return_concrete_lengths=True/False
set_seed(555)
# Check consistency between waveforms generated with and without standalone vocoder
self.assertTrue(
torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
"Mismatch in waveforms generated with and without the standalone vocoder."
)
self.assertEqual(
waveform_lengths,
waveform_lengths_with_vocoder,
"Waveform lengths differ between standalone and integrated vocoder generation."
)

# Test generation consistency without returning lengths
set_seed(555) # Reset seed for consistent results
waveforms_with_vocoder_no_lengths = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
attention_mask=inputs["attention_mask"],
vocoder=vocoder,
return_output_lengths=False,
)
self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))

# Check results when batching are consistent with results without batching
# Validate waveform consistency without length information
self.assertTrue(
torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
"Waveforms differ when generated with and without length information."
)

# Validate batch vs. single instance generation consistency
single_speaker_embedding = torch.zeros((1, 512), device=torch_device)
for i, text in enumerate(input_text):
set_seed(555) # make deterministic
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
spectrogram = model.generate_speech(
inputs = processor(
text=text, padding="max_length", max_length=128, return_tensors="pt"
).to(torch_device)
set_seed(555) # Reset seed for consistent results
single_spectrogram = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
speaker_embeddings=single_speaker_embedding,
)
self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
waveform = vocoder(spectrogram)
self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
# Check whether waveforms are the same with/without passing vocoder
set_seed(555)
waveform_with_vocoder = model.generate_speech(

# Check spectrogram shape consistency
self.assertEqual(
single_spectrogram.shape,
spectrograms[i][:spectrogram_lengths[i]].shape,
"Mismatch in spectrogram shape between batch and single instance generation."
)

# Generate and validate waveform for single instance
waveform = vocoder(single_spectrogram)
self.assertEqual(
waveform.shape,
waveforms[i][:waveform_lengths[i]].shape,
"Mismatch in waveform shape between batch and single instance generation."
)

# Check waveform consistency with integrated vocoder
set_seed(555) # Reset seed for consistent results
waveform_with_integrated_vocoder = model.generate_speech(
input_ids=inputs["input_ids"],
speaker_embeddings=speaker_embeddings,
speaker_embeddings=single_speaker_embedding,
vocoder=vocoder,
)
self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
self.assertTrue(
torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
"Mismatch in waveform between standalone and integrated vocoder for single instance generation."
)


@require_torch
Expand Down