-
Notifications
You must be signed in to change notification settings - Fork 26.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/speecht5 bug #28481
Fix/speecht5 bug #28481
Changes from 5 commits
55a8810
8da3d18
9b159e0
daacf11
aca772b
65121d5
b6e5a8c
7eb841b
8c813ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1029,107 +1029,164 @@ def _mock_init_weights(self, module): | |
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): | ||
@cached_property | ||
def default_model(self): | ||
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") | ||
return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device) | ||
|
||
@cached_property | ||
def default_processor(self): | ||
return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") | ||
|
||
@cached_property | ||
def default_vocoder(self): | ||
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") | ||
return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device) | ||
|
||
def test_generation(self): | ||
model = self.default_model | ||
model.to(torch_device) | ||
processor = self.default_processor | ||
|
||
set_seed(555) # make deterministic | ||
|
||
speaker_embeddings = torch.zeros((1, 512)).to(torch_device) | ||
|
||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" | ||
input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." | ||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device) | ||
|
||
speaker_embeddings = torch.zeros((1, 512), device=torch_device) | ||
|
||
# Generate speech and validate output dimensions | ||
set_seed(555) # Ensure deterministic behavior | ||
generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings) | ||
self.assertEqual(generated_speech.shape, (230, model.config.num_mel_bins)) | ||
|
||
set_seed(555) # make deterministic | ||
num_mel_bins = model.config.num_mel_bins | ||
self.assertEqual( | ||
generated_speech.shape[1], | ||
num_mel_bins, | ||
"Generated speech output has an unexpected number of mel bins." | ||
) | ||
|
||
# test model.generate, same method as generate_speech but with additional kwargs to absorb kwargs such as attention_mask | ||
# Validate generation with additional kwargs using model.generate; | ||
# same method as generate_speech | ||
set_seed(555) # Reset seed for consistent results | ||
generated_speech_with_generate = model.generate( | ||
input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings | ||
) | ||
self.assertEqual(generated_speech_with_generate.shape, (230, model.config.num_mel_bins)) | ||
self.assertEqual( | ||
generated_speech_with_generate.shape, | ||
generated_speech.shape, | ||
"Shape mismatch between generate_speech and generate methods." | ||
) | ||
|
||
def test_batch_generation(self): | ||
model = self.default_model | ||
model.to(torch_device) | ||
processor = self.default_processor | ||
vocoder = self.default_vocoder | ||
set_seed(555) # make deterministic | ||
|
||
input_text = [ | ||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel", | ||
"nor is mister quilter's manner less interesting than his matter", | ||
"he tells us that at this festive season of the year with christmas and rosebeaf looming before us", | ||
] | ||
inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device) | ||
inputs = processor( | ||
text=input_text, padding="max_length", max_length=128, return_tensors="pt" | ||
).to(torch_device) | ||
speaker_embeddings = torch.zeros((len(input_text), 512), device=torch_device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could be nice to test it with different speaker embeddings (e.g random + seed) ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
|
||
speaker_embeddings = torch.zeros((1, 512), device=torch_device) | ||
# Generate spectrograms | ||
set_seed(555) # Ensure deterministic behavior | ||
spectrograms, spectrogram_lengths = model.generate_speech( | ||
input_ids=inputs["input_ids"], | ||
speaker_embeddings=speaker_embeddings, | ||
attention_mask=inputs["attention_mask"], | ||
return_output_lengths=True, | ||
) | ||
) | || There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Following my previous comment, we should also test the one-to-many situation! |
||
self.assertEqual(spectrograms.shape, (3, 262, model.config.num_mel_bins)) | ||
|
||
# Validate generated spectrogram dimensions | ||
expected_batch_size = len(input_text) | ||
num_mel_bins = model.config.num_mel_bins | ||
actual_batch_size, _, actual_num_mel_bins = spectrograms.shape | ||
self.assertEqual( | ||
actual_batch_size, | ||
expected_batch_size, | ||
"Batch size of generated spectrograms is incorrect." | ||
) | ||
self.assertEqual( | ||
actual_num_mel_bins, | ||
num_mel_bins, | ||
"Number of mel bins in batch generated spectrograms is incorrect." | ||
) | ||
|
||
# Generate waveforms using the vocoder | ||
waveforms = vocoder(spectrograms) | ||
waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] | ||
|
||
# Check waveform results are the same with or without using the vocoder | ||
set_seed(555) | ||
# Validate generation with integrated vocoder | ||
set_seed(555) # Reset seed for consistent results | ||
waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech( | ||
input_ids=inputs["input_ids"], | ||
speaker_embeddings=speaker_embeddings, | ||
attention_mask=inputs["attention_mask"], | ||
vocoder=vocoder, | ||
return_output_lengths=True, | ||
) | ||
self.assertTrue(torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8)) | ||
self.assertEqual(waveform_lengths, waveform_lengths_with_vocoder) | ||
|
||
# Check waveform results are the same with return_output_lengths=True/False | ||
set_seed(555) | ||
# Check consistency between waveforms generated with and without standalone vocoder | ||
self.assertTrue( | ||
torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8), | ||
"Mismatch in waveforms generated with and without the standalone vocoder." | ||
) | ||
self.assertEqual( | ||
waveform_lengths, | ||
waveform_lengths_with_vocoder, | ||
"Waveform lengths differ between standalone and integrated vocoder generation." | ||
) | ||
|
||
# Test generation consistency without returning lengths | ||
set_seed(555) # Reset seed for consistent results | ||
waveforms_with_vocoder_no_lengths = model.generate_speech( | ||
input_ids=inputs["input_ids"], | ||
speaker_embeddings=speaker_embeddings, | ||
attention_mask=inputs["attention_mask"], | ||
vocoder=vocoder, | ||
return_output_lengths=False, | ||
) | ||
self.assertTrue(torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8)) | ||
|
||
# Check results when batching are consistent with results without batching | ||
# Validate waveform consistency without length information | ||
self.assertTrue( | ||
torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8), | ||
"Waveforms differ when generated with and without length information." | ||
) | ||
|
||
# Validate batch vs. single instance generation consistency | ||
single_speaker_embedding = torch.zeros((1, 512), device=torch_device) | ||
for i, text in enumerate(input_text): | ||
set_seed(555) # make deterministic | ||
inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device) | ||
spectrogram = model.generate_speech( | ||
inputs = processor( | ||
text=text, padding="max_length", max_length=128, return_tensors="pt" | ||
).to(torch_device) | ||
set_seed(555) # Reset seed for consistent results | ||
single_spectrogram = model.generate_speech( | ||
input_ids=inputs["input_ids"], | ||
speaker_embeddings=speaker_embeddings, | ||
speaker_embeddings=single_speaker_embedding, | ||
) | ||
self.assertEqual(spectrogram.shape, spectrograms[i][: spectrogram_lengths[i]].shape) | ||
self.assertTrue(torch.allclose(spectrogram, spectrograms[i][: spectrogram_lengths[i]], atol=5e-3)) | ||
waveform = vocoder(spectrogram) | ||
self.assertEqual(waveform.shape, waveforms[i][: waveform_lengths[i]].shape) | ||
# Check whether waveforms are the same with/without passing vocoder | ||
set_seed(555) | ||
waveform_with_vocoder = model.generate_speech( | ||
|
||
# Check spectrogram shape consistency | ||
self.assertEqual( | ||
single_spectrogram.shape, | ||
spectrograms[i][:spectrogram_lengths[i]].shape, | ||
"Mismatch in spectrogram shape between batch and single instance generation." | ||
) | ||
|
||
# Generate and validate waveform for single instance | ||
waveform = vocoder(single_spectrogram) | ||
self.assertEqual( | ||
waveform.shape, | ||
waveforms[i][:waveform_lengths[i]].shape, | ||
"Mismatch in waveform shape between batch and single instance generation." | ||
) | ||
|
||
# Check waveform consistency with integrated vocoder | ||
set_seed(555) # Reset seed for consistent results | ||
waveform_with_integrated_vocoder = model.generate_speech( | ||
input_ids=inputs["input_ids"], | ||
speaker_embeddings=speaker_embeddings, | ||
speaker_embeddings=single_speaker_embedding, | ||
vocoder=vocoder, | ||
) | ||
self.assertTrue(torch.allclose(waveform, waveform_with_vocoder, atol=1e-8)) | ||
self.assertTrue( | ||
torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8), | ||
"Mismatch in waveform between standalone and integrated vocoder for single instance generation." | ||
) | ||
|
||
|
||
@require_torch | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The way I see it, there are two possible situations:
1. batch size = number of speaker embeddings
2. a single speaker embedding shared across the whole batch (the one-to-many case)
Your proposed solution addresses situation 1, and the previous solution addressed situation 2. To be complete, we should have code that addresses both situations (and that throws an Error in other cases).