From 07ae53e6e77ec6ff4fb25fbacfec4b11cfc82749 Mon Sep 17 00:00:00 2001
From: Nima Yaqmuri <62163525+NimaYaqmuri@users.noreply.github.com>
Date: Tue, 16 Jan 2024 17:44:28 +0330
Subject: [PATCH] Fix/speecht5 bug (#28481)

* Fix bug in SpeechT5 speech decoder prenet's forward method

- Removed the redundant `repeat` operation on speaker_embeddings in the
  forward method. This line erroneously duplicated the embeddings along the
  batch dimension, producing an input of the wrong size for the subsequent
  concatenation and degrading performance.
- Preserved the rest of the method, keeping the speech decoder prenet's
  forward pass otherwise intact.
- This resolves a critical bug in the model's handling of speaker embeddings.

* Refactor SpeechT5 text to speech integration tests

- Updated SpeechT5ForTextToSpeechIntegrationTests to accommodate variability
  in sequence lengths caused by dropout in the speech decoder prenet, making
  the tests robust against random variation in the generated speech.
- Replaced hardcoded dimensions in test assertions with dynamic checks based
  on the model configuration and seed settings, so the tests remain valid
  across runs and configurations.
- Added test cases that validate the shapes of generated spectrograms and
  waveforms, using fixed seeds for consistent, predictable behavior.
- Fixed existing test cases whose incorrect assumptions about output shapes
  could lead to spurious failures.

* Enhance handling of speaker embeddings in SpeechT5

- Refined the generate and generate_speech functions to handle two scenarios
  for speaker embeddings: one embedding per sample (matching the batch size)
  and one-to-many (a single embedding shared by all samples in the batch).
- A single embedding is now repeated across the batch, and a ValueError is
  raised for any other dimension mismatch.
- Added corresponding test cases to validate both scenarios, ensuring
  complete coverage for diverse speaker embedding situations.

* Improve Test Robustness with Randomized Speaker Embeddings
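In shapes, the prenet bug and its fix look as follows; a minimal standalone
sketch with illustrative sizes, not code taken from the patched module:

    import torch

    batch_size, seq_len, speaker_dim = 2, 5, 512
    speaker_embeddings = torch.randn(batch_size, speaker_dim)

    # Correct: broadcast each sample's embedding across the time axis.
    expanded = speaker_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
    assert expanded.shape == (batch_size, seq_len, speaker_dim)

    # Buggy pre-patch behavior: an extra repeat multiplied the batch
    # dimension again, yielding batch_size**2 rows and breaking the
    # subsequent torch.cat with inputs_embeds.
    repeated = expanded.repeat(batch_size, 1, 1)
    assert repeated.shape == (batch_size * batch_size, seq_len, speaker_dim)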
---
 .../models/speecht5/modeling_speecht5.py      |  26 +-
 .../models/speecht5/test_modeling_speecht5.py | 228 +++++++++++++++---
 2 files changed, 215 insertions(+), 39 deletions(-)

diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 94334e76ef4b17..bbdaaec473fa78 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -664,13 +664,11 @@ def __init__(self, config):
         )
 
         self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)
-
         self.encode_positions = SpeechT5ScaledPositionalEncoding(
             config.positional_dropout,
             config.hidden_size,
             config.max_speech_positions,
         )
-
         self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
 
     def _consistent_dropout(self, inputs_embeds, p):
@@ -695,9 +693,7 @@ def forward(
 
         if speaker_embeddings is not None:
             speaker_embeddings = nn.functional.normalize(speaker_embeddings)
-            speaker_embeddings = speaker_embeddings.unsqueeze(1)
-            speaker_embeddings = speaker_embeddings.expand(-1, inputs_embeds.size(1), -1)
-            speaker_embeddings = speaker_embeddings.repeat(inputs_embeds.size(0), 1, 1)
+            speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
             inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
             inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
 
@@ -2825,6 +2821,16 @@ def generate(
             `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
             output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
+        if speaker_embeddings is not None:
+            batch_size = input_ids.size(0)
+            if speaker_embeddings.size(0) != batch_size:
+                if speaker_embeddings.size(0) == 1:
+                    speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
+                else:
+                    raise ValueError(
+                        "The first dimension of speaker_embeddings must be either 1 or the same as batch_size."
+                    )
+
         return _generate_speech(
             self,
             input_ids,
@@ -2911,6 +2917,16 @@ def generate_speech(
             `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
             output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
         """
+        if speaker_embeddings is not None:
+            batch_size = input_ids.size(0)
+            if speaker_embeddings.size(0) != batch_size:
+                if speaker_embeddings.size(0) == 1:
+                    speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
+                else:
+                    raise ValueError(
+                        "The first dimension of speaker_embeddings must be either 1 or the same as batch_size."
+                    )
+
         return _generate_speech(
             self,
             input_ids,
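For readability, the speaker-embedding handling added to both generate and
generate_speech above boils down to this standalone sketch; the function name
is illustrative, not from the library:

    import torch

    def broadcast_speaker_embeddings(speaker_embeddings: torch.Tensor, batch_size: int) -> torch.Tensor:
        # One embedding per sample: use as-is.
        if speaker_embeddings.size(0) == batch_size:
            return speaker_embeddings
        # One shared embedding: repeat it for every sample in the batch.
        if speaker_embeddings.size(0) == 1:
            return speaker_embeddings.repeat(batch_size, 1)
        raise ValueError("The first dimension of speaker_embeddings must be either 1 or the same as batch_size.")

    assert broadcast_speaker_embeddings(torch.zeros(1, 512), 3).shape == (3, 512)
    assert broadcast_speaker_embeddings(torch.zeros(3, 512), 3).shape == (3, 512)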
diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py
index c6b4b24873a2fe..7849b59d2935a7 100644
--- a/tests/models/speecht5/test_modeling_speecht5.py
+++ b/tests/models/speecht5/test_modeling_speecht5.py
@@ -1029,7 +1029,7 @@ def _mock_init_weights(self, module):
 class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
     @cached_property
     def default_model(self):
-        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)
 
     @cached_property
     def default_processor(self):
@@ -1037,37 +1037,40 @@ def default_processor(self):
 
     @cached_property
     def default_vocoder(self):
-        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)
 
     def test_generation(self):
         model = self.default_model
-        model.to(torch_device)
         processor = self.default_processor
 
-        set_seed(555)  # make deterministic
-
-        speaker_embeddings = torch.zeros((1, 512)).to(torch_device)
-
-        input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
+        input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
         input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
+        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
 
+        # Generate speech and validate output dimensions
+        set_seed(555)  # Ensure deterministic behavior
         generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
-        self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins))
-
-        set_seed(555)  # make deterministic
+        num_mel_bins = model.config.num_mel_bins
+        self.assertEqual(
+            generated_speech.shape[1], num_mel_bins, "Generated speech output has an unexpected number of mel bins."
+        )
 
-        # test model.generate, same method than generate_speech but with additional kwargs to absorb kwargs such as attention_mask
+        # Validate generation with additional kwargs using model.generate;
+        # same method as generate_speech
+        set_seed(555)  # Reset seed for consistent results
         generated_speech_with_generate = model.generate(
             input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
         )
-        self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins))
+        self.assertEqual(
+            generated_speech_with_generate.shape,
+            generated_speech.shape,
+            "Shape mismatch between generate_speech and generate methods.",
+        )
 
-    def test_batch_generation(self):
+    def test_one_to_many_generation(self):
         model = self.default_model
-        model.to(torch_device)
         processor = self.default_processor
         vocoder = self.default_vocoder
-        set_seed(555)  # make deterministic
 
         input_text = [
             "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
@@ -1075,20 +1078,32 @@ def test_batch_generation(self):
             "nor is mister quilter's manner less interesting than his matter",
             "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
         ]
         inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
-
         speaker_embeddings = torch.zeros((1, 512), device=torch_device)
+
+        # Generate spectrograms
+        set_seed(555)  # Ensure deterministic behavior
         spectrograms, spectrogram_lengths = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
             attention_mask=inputs["attention_mask"],
             return_output_lengths=True,
         )
-        self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins))
+
+        # Validate generated spectrogram dimensions
+        expected_batch_size = len(input_text)
+        num_mel_bins = model.config.num_mel_bins
+        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
+        self.assertEqual(actual_batch_size, expected_batch_size, "Batch size of generated spectrograms is incorrect.")
+        self.assertEqual(
+            actual_num_mel_bins, num_mel_bins, "Number of mel bins in batch generated spectrograms is incorrect."
+        )
+
+        # Generate waveforms using the vocoder
         waveforms = vocoder(spectrograms)
         waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
-        # Check waveform results are the same with or without using vocder
-        set_seed(555)
+
+        # Validate generation with integrated vocoder
+        set_seed(555)  # Reset seed for consistent results
         waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
@@ -1096,11 +1111,20 @@
             vocoder=vocoder,
             return_output_lengths=True,
         )
-        self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8))
-        self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder)
-
-        # Check waveform results are the same with return_concrete_lengths=True/False
-        set_seed(555)
+
+        # Check consistency between waveforms generated with and without standalone vocoder
+        self.assertTrue(
+            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
+            "Mismatch in waveforms generated with and without the standalone vocoder.",
+        )
+        self.assertEqual(
+            waveform_lengths,
+            waveform_lengths_with_vocoder,
+            "Waveform lengths differ between standalone and integrated vocoder generation.",
+        )
+
+        # Test generation consistency without returning lengths
+        set_seed(555)  # Reset seed for consistent results
         waveforms_with_vocoder_no_lengths = model.generate_speech(
             input_ids=inputs["input_ids"],
             speaker_embeddings=speaker_embeddings,
@@ -1108,28 +1132,164 @@
             vocoder=vocoder,
             return_output_lengths=False,
         )
-        self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8))
-        # Check results when batching are consistent with results without batching
+
+        # Validate waveform consistency without length information
+        self.assertTrue(
+            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
+            "Waveforms differ when generated with and without length information.",
+        )
+
+        # Validate batch vs. single instance generation consistency
         for i, text in enumerate(input_text):
-            set_seed(555)  # make deterministic
             inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+            set_seed(555)  # Reset seed for consistent results
             spectrogram = model.generate_speech(
                 input_ids=inputs["input_ids"],
                 speaker_embeddings=speaker_embeddings,
             )
-            self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape)
-            self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3))
+
+            # Check spectrogram shape consistency
+            self.assertEqual(
+                spectrogram.shape,
+                spectrograms[i][: spectrogram_lengths[i]].shape,
+                "Mismatch in spectrogram shape between batch and single instance generation.",
+            )
+
+            # Generate and validate waveform for single instance
             waveform = vocoder(spectrogram)
-            self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape)
-            # Check whether waveforms are the same with/without passing vocoder
-            set_seed(555)
-            waveform_with_vocoder = model.generate_speech(
+            self.assertEqual(
+                waveform.shape,
+                waveforms[i][: waveform_lengths[i]].shape,
+                "Mismatch in waveform shape between batch and single instance generation.",
+            )
+
+            # Check waveform consistency with integrated vocoder
+            set_seed(555)  # Reset seed for consistent results
+            waveform_with_integrated_vocoder = model.generate_speech(
                 input_ids=inputs["input_ids"],
                 speaker_embeddings=speaker_embeddings,
                 vocoder=vocoder,
             )
-            self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8))
+            self.assertTrue(
+                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
+                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
+            )
+
+    def test_batch_generation(self):
+        model = self.default_model
+        processor = self.default_processor
+        vocoder = self.default_vocoder
+
+        input_text = [
+            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
+            "nor is mister quilter's manner less interesting than his matter",
+            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
+        ]
+        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+        set_seed(555)  # Ensure deterministic behavior
+        speaker_embeddings = torch.randn((len(input_text), 512), device=torch_device)
+
+        # Generate spectrograms
+        set_seed(555)  # Reset seed for consistent results
+        spectrograms, spectrogram_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            return_output_lengths=True,
+        )
+
+        # Validate generated spectrogram dimensions
+        expected_batch_size = len(input_text)
+        num_mel_bins = model.config.num_mel_bins
+        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
+        self.assertEqual(
+            actual_batch_size,
+            expected_batch_size,
+            "Batch size of generated spectrograms is incorrect.",
+        )
+        self.assertEqual(
+            actual_num_mel_bins,
+            num_mel_bins,
+            "Number of mel bins in batch generated spectrograms is incorrect.",
+        )
+
+        # Generate waveforms using the vocoder
+        waveforms = vocoder(spectrograms)
+        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
+
+        # Validate generation with integrated vocoder
+        set_seed(555)  # Reset seed for consistent results
+        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=True,
+        )
+
+        # Check consistency between waveforms generated with and without standalone vocoder
+        self.assertTrue(
+            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
+            "Mismatch in waveforms generated with and without the standalone vocoder.",
+        )
+        self.assertEqual(
+            waveform_lengths,
+            waveform_lengths_with_vocoder,
+            "Waveform lengths differ between standalone and integrated vocoder generation.",
+        )
+
+        # Test generation consistency without returning lengths
+        set_seed(555)  # Reset seed for consistent results
+        waveforms_with_vocoder_no_lengths = model.generate_speech(
+            input_ids=inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings,
+            attention_mask=inputs["attention_mask"],
+            vocoder=vocoder,
+            return_output_lengths=False,
+        )
+
+        # Validate waveform consistency without length information
+        self.assertTrue(
+            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
+            "Waveforms differ when generated with and without length information.",
+        )
+
+        # Validate batch vs. single instance generation consistency
+        for i, text in enumerate(input_text):
+            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
+            current_speaker_embedding = speaker_embeddings[i].unsqueeze(0)
+            set_seed(555)  # Reset seed for consistent results
+            spectrogram = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=current_speaker_embedding,
+            )
+
+            # Check spectrogram shape consistency
+            self.assertEqual(
+                spectrogram.shape,
+                spectrograms[i][: spectrogram_lengths[i]].shape,
+                "Mismatch in spectrogram shape between batch and single instance generation.",
+            )
+
+            # Generate and validate waveform for single instance
+            waveform = vocoder(spectrogram)
+            self.assertEqual(
+                waveform.shape,
+                waveforms[i][: waveform_lengths[i]].shape,
+                "Mismatch in waveform shape between batch and single instance generation.",
+            )
+
+            # Check waveform consistency with integrated vocoder
+            set_seed(555)  # Reset seed for consistent results
+            waveform_with_integrated_vocoder = model.generate_speech(
+                input_ids=inputs["input_ids"],
+                speaker_embeddings=current_speaker_embedding,
+                vocoder=vocoder,
+            )
+            self.assertTrue(
+                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
+                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
+            )
 
 
 @require_torch
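For context, the behavior these tests exercise corresponds roughly to the
following user-facing usage; a sketch assuming the microsoft/speecht5_tts and
microsoft/speecht5_hifigan checkpoints, with random embeddings standing in
for real speaker vectors:

    import torch
    from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    texts = ["the first sentence", "the second, slightly longer sentence"]
    inputs = processor(text=texts, padding="max_length", max_length=128, return_tensors="pt")

    # After this patch, either (1, 512) or (len(texts), 512) is accepted:
    # a single embedding is broadcast to the whole batch.
    speaker_embeddings = torch.randn(len(texts), 512)

    waveforms, waveform_lengths = model.generate_speech(
        input_ids=inputs["input_ids"],
        speaker_embeddings=speaker_embeddings,
        attention_mask=inputs["attention_mask"],
        vocoder=vocoder,
        return_output_lengths=True,
    )
    # Trim padding from each sample using the returned lengths.
    per_sample_waveforms = [waveforms[i, : waveform_lengths[i]] for i in range(len(texts))]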