Closed
Description
Here, at line 48 in bd4c341:
    x = x + pitch_embedding
and then, at line 50 in bd4c341:
    energy_prediction = self.energy_predictor(x)
But as per the paper, the input to the energy predictor should be the output of the length regulator, not x with the pitch embedding already added.
See the FastSpeech 2 diagram: the input to the energy predictor is x, the output of the length regulator, without the pitch component.
The code should instead look like this:
def forward(self, x, duration_target=None, pitch_target=None, energy_target=None, max_length=None):
    duration_prediction = self.duration_predictor(x)
    if duration_target is not None:
        x, mel_pos = self.length_regulator(x, duration_target, max_length)
    else:
        duration_rounded = torch.round(duration_prediction)
        x, mel_pos = self.length_regulator(x, duration_rounded)

    pitch_prediction = self.pitch_predictor(x)
    if pitch_target is not None:
        pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_target, self.pitch_bins))
    else:
        pitch_embedding = self.pitch_embedding(torch.bucketize(pitch_prediction, self.pitch_bins))

    energy_prediction = self.energy_predictor(x)
    if energy_target is not None:
        energy_embedding = self.energy_embedding(torch.bucketize(energy_target, self.energy_bins))
    else:
        energy_embedding = self.energy_embedding(torch.bucketize(energy_prediction, self.energy_bins))

    x = x + pitch_embedding
    x = x + energy_embedding

    return x, duration_prediction, pitch_prediction, energy_prediction, mel_pos
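
To make the difference between the two orderings concrete, here is a minimal, framework-free sketch. All functions in it (`pitch_embedding`, `energy_predictor`, `adaptor_sequential`, `adaptor_parallel`) are hypothetical toy stand-ins operating on plain lists, not the repository's modules. It only illustrates which tensor reaches the energy predictor in each case.

```python
# Toy stand-ins (hypothetical): plain Python lists instead of tensors.

def pitch_embedding(frame_value):
    """Pretend embedding: a fixed offset for every frame (assumption for illustration)."""
    return 10.0


def energy_predictor(frames):
    """Pretend predictor: doubles its input, so the input it received is easy to spot."""
    return [2.0 * v for v in frames]


def adaptor_sequential(frames):
    """Ordering reported in the issue: pitch is added before energy is predicted."""
    with_pitch = [v + pitch_embedding(v) for v in frames]
    return energy_predictor(with_pitch)


def adaptor_parallel(frames):
    """Ordering proposed in the issue: the energy predictor sees the length regulator output."""
    return energy_predictor(frames)


length_regulated = [1.0, 2.0, 3.0]  # stands in for the length regulator output
print(adaptor_sequential(length_regulated))  # [22.0, 24.0, 26.0]
print(adaptor_parallel(length_regulated))    # [2.0, 4.0, 6.0]
```

In the sequential variant the energy prediction changes whenever the pitch embedding changes. In the parallel variant it depends only on the length regulator output.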
rafaelvalle