Commit
fix glow-tts inference and forward functions for handling cond_input
and refactor its test
erogol committed Jun 28, 2021
1 parent f840268 commit 6c495c6
Showing 2 changed files with 36 additions and 17 deletions.
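
The heart of the change, repeated in each touched function below, is a guard that pulls the speaker embedding out of an optional cond_input dict instead of taking a bare g argument. A minimal standalone sketch of that pattern (the extract_x_vectors helper name is illustrative, not part of the repository):

import torch
import torch.nn.functional as F


def extract_x_vectors(cond_input):
    # Mirror of the guard introduced in this commit: tolerate both a
    # missing dict and a dict that lacks the "x_vectors" key.
    if cond_input is not None and "x_vectors" in cond_input:
        return cond_input["x_vectors"]
    return None


g = extract_x_vectors({"x_vectors": torch.randn(2, 256)})
if g is not None:
    g = F.normalize(g).unsqueeze(-1)  # [B, C] -> [B, C, 1], as in the model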
28 changes: 20 additions & 8 deletions TTS/tts/models/glow_tts.py
@@ -154,10 +154,10 @@ def forward(
             y_lengths: B
             g: [B, C] or B
         """
-        y_max_length = y.size(2)
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
         # norm speaker embeddings
-        g = cond_input["x_vectors"]
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
@@ -196,19 +196,23 @@ def forward(
         return outputs

     @torch.no_grad()
-    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+    def inference_with_MAS(
+        self, x, x_lengths, y=None, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         It's similar to the teacher forcing in Tacotron.
         It was proposed in: https://arxiv.org/abs/2104.05557
         Shapes:
             x: [B, T]
             x_lenghts: B
-            y: [B, C, T]
+            y: [B, T, C]
             y_lengths: B
             g: [B, C] or B
         """
+        y = y.transpose(1, 2)
         y_max_length = y.size(2)
         # norm speaker embeddings
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.external_speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
@@ -253,14 +257,18 @@ def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
         return outputs

     @torch.no_grad()
-    def decoder_inference(self, y, y_lengths=None, g=None):
+    def decoder_inference(
+        self, y, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
-            y: [B, C, T]
+            y: [B, T, C]
             y_lengths: B
             g: [B, C] or B
         """
+        y = y.transpose(1, 2)
         y_max_length = y.size(2)
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         # norm speaker embeddings
         if g is not None:
             if self.external_speaker_embedding_dim:
@@ -276,10 +284,14 @@ def decoder_inference(self, y, y_lengths=None, g=None):
         # reverse decoder and predict
         y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

-        return y, logdet
+        outputs = {}
+        outputs["model_outputs"] = y
+        outputs["logdet"] = logdet
+        return outputs

     @torch.no_grad()
-    def inference(self, x, x_lengths, g=None):
+    def inference(self, x, x_lengths, cond_input={"x_vectors": None}):  # pylint: disable=dangerous-default-value
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
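
With the new signatures in place, callers pass conditioning through cond_input and read tensors out of a returned dict instead of unpacking tuples. A hedged usage sketch, assuming model is an already-initialized GlowTTS instance and num_mels is its spectrogram channel count (construction omitted):

import torch

input_dummy = torch.randint(0, 24, (8, 128)).long()  # token ids: [B, T]
input_lengths = torch.full((8,), 128).long()
mel_spec = torch.rand(8, 30, num_mels)               # note the new [B, T, C] layout
mel_lengths = torch.randint(20, 30, (8,)).long()

outputs = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
dec_outputs = model.decoder_inference(mel_spec, mel_lengths)

# Both calls now return dicts keyed by "model_outputs" (plus "logdet" for
# the decoder) rather than positional tuples.
assert outputs["model_outputs"].shape == dec_outputs["model_outputs"].shape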
25 changes: 16 additions & 9 deletions tests/tts_tests/test_glow_tts.py
@@ -34,7 +34,7 @@ def test_train_step():
     input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
     input_lengths = torch.randint(100, 129, (8,)).long().to(device)
     input_lengths[-1] = 128
-    mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+    mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
     mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
     speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
@@ -114,10 +114,17 @@ def test_train_step():
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     for _ in range(5):
         optimizer.zero_grad()
-        z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            input_dummy, input_lengths, mel_spec, mel_lengths, None
+        outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
+        loss_dict = criterion(
+            outputs["model_outputs"],
+            outputs["y_mean"],
+            outputs["y_log_scale"],
+            outputs["logdet"],
+            mel_lengths,
+            outputs["durations_log"],
+            outputs["total_durations_log"],
+            input_lengths,
         )
-        loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths)
         loss = loss_dict["loss"]
         loss.backward()
         optimizer.step()
@@ -137,7 +144,7 @@ def test_inference():
     input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
     input_lengths = torch.randint(100, 129, (8,)).long().to(device)
     input_lengths[-1] = 128
-    mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+    mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
     mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
     speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
@@ -175,12 +182,12 @@ def test_inference():
     print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))

     # inference encoder and decoder with MAS
-    y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
+    y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)

-    y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
+    y2 = model.decoder_inference(mel_spec, mel_lengths)

     assert (
-        y_dec.shape == y.shape
+        y2["model_outputs"].shape == y["model_outputs"].shape
     ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
-        y.shape, y_dec.shape
+        y["model_outputs"].shape, y2["model_outputs"].shape
     )
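
Both test hunks also flip the mel layout: spectrograms are now built as [B, T, C] and transposed back to [B, C, T] inside the model, so y_max_length = y.size(2) still reads the time axis. A small sketch of the equivalence, using only torch:

import torch

num_mels = 80
old_style = torch.rand(8, num_mels, 30)  # [B, C, T], the pre-commit layout
new_style = old_style.transpose(1, 2)    # [B, T, C], what the tests now build

# Inside forward/inference the model undoes this with y.transpose(1, 2),
# so y_max_length = y.size(2) still indexes the time dimension.
assert new_style.transpose(1, 2).shape == old_style.shape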
