Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding neural HMM TTS #2271

Closed
wants to merge 83 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
405bffe
Adding encoder
shivammehta25 Nov 26, 2022
d607993
currently modifying hmm
shivammehta25 Nov 27, 2022
a324920
Adding hmm
shivammehta25 Nov 28, 2022
8628648
Adding overflow
shivammehta25 Nov 30, 2022
6ec83c4
Adding overflow setting up flat start
shivammehta25 Dec 1, 2022
783a982
Removing runs
shivammehta25 Dec 1, 2022
10f15e0
adding normalization parameters
shivammehta25 Dec 1, 2022
aff8b1f
Fixing models on same device
shivammehta25 Dec 1, 2022
62941d6
Training overflow and plotting evaluations
shivammehta25 Dec 2, 2022
f448ea4
Adding inference
shivammehta25 Dec 3, 2022
ff33837
At the end of epoch the test sentences are coming on cpu instead of gpu
shivammehta25 Dec 4, 2022
3edb0d2
Adding figures from model during training to monitor
shivammehta25 Dec 5, 2022
5fc800c
reverting tacotron2 training recipe
shivammehta25 Dec 5, 2022
427dfe5
fixing inference on gpu for test sentences on config
shivammehta25 Dec 5, 2022
ecc12c6
moving helpers and texts within overflows source code
shivammehta25 Dec 5, 2022
b86f3f8
renaming to overflow
shivammehta25 Dec 5, 2022
995ee93
moving loss to the model file
shivammehta25 Dec 5, 2022
5b0fe46
Fixing the rename
shivammehta25 Dec 5, 2022
5377f87
Model training but not plotting the test config sentences's audios
shivammehta25 Dec 5, 2022
bd5be6c
Formatting logs
shivammehta25 Dec 5, 2022
755aa6f
Changing model name to camelcase
shivammehta25 Dec 5, 2022
1350a4b
Fixing test log
shivammehta25 Dec 5, 2022
3c986fd
Fixing plotting bug
shivammehta25 Dec 6, 2022
4a5b1a0
Adding some tests
shivammehta25 Dec 6, 2022
5b1dabc
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Dec 7, 2022
f43d7e3
Adding more tests to overflow
shivammehta25 Dec 8, 2022
c3d0167
Adding all tests for overflow
shivammehta25 Dec 9, 2022
ddefe34
making changes to camel case in config
shivammehta25 Dec 9, 2022
c2df9f3
Adding information about parameters and docstring
shivammehta25 Dec 10, 2022
9927434
removing compute_mel_statistics moved statistic computation to the mo…
shivammehta25 Dec 10, 2022
340cd0b
Added overflow in readme
shivammehta25 Dec 10, 2022
aca3fe1
Adding more test cases, now it doesn't saves transition_p like tensor…
shivammehta25 Dec 11, 2022
e7c11dd
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Dec 14, 2022
7e2dbb1
uncommenting the approximation to stablize the training
shivammehta25 Dec 14, 2022
be09d6c
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Dec 14, 2022
282de93
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Dec 22, 2022
5df4fe8
Adding encoder
shivammehta25 Nov 26, 2022
fa25825
currently modifying hmm
shivammehta25 Nov 27, 2022
3cb0f78
Adding hmm
shivammehta25 Nov 28, 2022
9984afa
Adding overflow
shivammehta25 Nov 30, 2022
4dad45c
Adding overflow setting up flat start
shivammehta25 Dec 1, 2022
377bd3e
Removing runs
shivammehta25 Dec 1, 2022
a441c71
adding normalization parameters
shivammehta25 Dec 1, 2022
995ac14
Fixing models on same device
shivammehta25 Dec 1, 2022
97b985b
Training overflow and plotting evaluations
shivammehta25 Dec 2, 2022
227077a
Adding inference
shivammehta25 Dec 3, 2022
bea46cc
At the end of epoch the test sentences are coming on cpu instead of gpu
shivammehta25 Dec 4, 2022
03d028e
Adding figures from model during training to monitor
shivammehta25 Dec 5, 2022
fc3c641
reverting tacotron2 training recipe
shivammehta25 Dec 5, 2022
c429837
fixing inference on gpu for test sentences on config
shivammehta25 Dec 5, 2022
b804a12
moving helpers and texts within overflows source code
shivammehta25 Dec 5, 2022
3149b43
renaming to overflow
shivammehta25 Dec 5, 2022
8aff87a
moving loss to the model file
shivammehta25 Dec 5, 2022
8d7b0e7
Fixing the rename
shivammehta25 Dec 5, 2022
8aaffed
Model training but not plotting the test config sentences's audios
shivammehta25 Dec 5, 2022
648b2c3
Formatting logs
shivammehta25 Dec 5, 2022
d22c6c0
Changing model name to camelcase
shivammehta25 Dec 5, 2022
6e08e4f
Fixing test log
shivammehta25 Dec 5, 2022
9394ce0
Fixing plotting bug
shivammehta25 Dec 6, 2022
e115361
Adding some tests
shivammehta25 Dec 6, 2022
7a541b9
Adding more tests to overflow
shivammehta25 Dec 8, 2022
1dccc29
Adding all tests for overflow
shivammehta25 Dec 9, 2022
1b1bf1f
making changes to camel case in config
shivammehta25 Dec 9, 2022
916b98e
Adding information about parameters and docstring
shivammehta25 Dec 10, 2022
6eff37c
removing compute_mel_statistics moved statistic computation to the mo…
shivammehta25 Dec 10, 2022
8a8dd1d
Added overflow in readme
shivammehta25 Dec 10, 2022
e738c0c
Adding more test cases, now it doesn't saves transition_p like tensor…
shivammehta25 Dec 11, 2022
479c0cf
Handle espeak 1.48.15 (#2203)
erogol Dec 12, 2022
4f02e2c
Python API implementation (#2195)
erogol Dec 12, 2022
89b9868
Update README (#2204)
erogol Dec 12, 2022
684adb0
Adding missing key to formatter (#2194)
p0p4k Dec 12, 2022
55801cc
Add YourTTS VCTK recipe (#2198)
Edresson Dec 12, 2022
a0be902
Add Original YourTTS vocabulary for full transfer learning (#2206)
Edresson Dec 13, 2022
f3fe409
uncommenting the approximation to stablize the training
shivammehta25 Dec 14, 2022
aedd795
Adding pre-trained Overflow model (#2211)
erogol Dec 14, 2022
253b03f
Fixup overflow (#2218)
erogol Dec 14, 2022
c2ce4fb
Bump up to v0.10.0
erogol Dec 15, 2022
fd5ad8c
Add Ukrainian LADA (female) voice
egorsmkv Dec 16, 2022
1260c7f
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Dec 30, 2022
f73cd29
Merge branch 'coqui-ai:dev' into dev
shivammehta25 Jan 3, 2023
2abbc97
Merge branch 'dev' of github.com:shivammehta25/TTS into dev
shivammehta25 Jan 5, 2023
790b846
Adding a config flag to train neural HMM TTS instead of overflow
shivammehta25 Jan 9, 2023
a8d0b22
Backwards compatibility: Fixing model zoo if the flag is not set, set it
shivammehta25 Jan 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding all tests for overflow
  • Loading branch information
shivammehta25 committed Dec 9, 2022
commit c3d0167733e18b1a76989f67cbbf560f9f260c1d
2 changes: 1 addition & 1 deletion TTS/tts/layers/overflow/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def forward(self, ar_mels, inputs):

Args:
ar_mel_inputs (torch.FloatTensor): shape (batch, prenet_dim)
states (torch.FloatTensor): (hidden_states, hidden_state_dim)
states (torch.FloatTensor): (batch, hidden_states, hidden_state_dim)

Returns:
means: means for the emission observation for each feature
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/layers/overflow/neural_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,7 @@ def inference(
max_sampling_time: int,
duration_threshold: float,
):
"""Sampling from autoregressive neural HMM
TODO: Add support for batched inference
"""Inference from autoregressive neural HMM

Args:
inputs (torch.FloatTensor): input states
Expand Down Expand Up @@ -379,6 +378,7 @@ def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_
- shape: :math:`(1)`
sampling_temp (float): sampling temperature
max_sampling_time (int): max sampling time
duration_threshold (float): duration threshold to switch to next state

Returns:
outputs (torch.FloatTensor): Output Observations
Expand Down
170 changes: 159 additions & 11 deletions tests/tts_tests/test_overflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from tests import get_tests_output_path
from TTS.tts.configs.overflow_config import OverflowConfig
from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel
from TTS.tts.models.overflow import Overflow
from TTS.tts.utils.helpers import sequence_mask
from TTS.utils.audio import AudioProcessor
Expand Down Expand Up @@ -71,7 +71,7 @@ def weight_reset(m):
model.apply(fn=weight_reset)


class TestOverFlow(unittest.TestCase):
class TestOverflow(unittest.TestCase):
def test_forward(self):
model = get_model()
input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
Expand All @@ -93,7 +93,7 @@ def test_init_from_config(self):
self.assertEqual(model.prenet_dim, config.prenet_dim)


class TestOverFlowPrenet(unittest.TestCase):
class TestOverflowEncoder(unittest.TestCase):
@staticmethod
def get_encoder(state_per_phone):
config = deepcopy(config_global)
Expand All @@ -118,7 +118,7 @@ def test_inference_with_state_per_phone_multiplication(self):
self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p)


class TestOverFlowUtils(unittest.TestCase):
class TestOverflowUtils(unittest.TestCase):
def test_logsumexp(self):
a = torch.randn(10) # random numbers
self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all())
Expand All @@ -130,7 +130,7 @@ def test_logsumexp(self):
self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all())


class TestOverFlowDecoder(unittest.TestCase):
class TestOverflowDecoder(unittest.TestCase):
@staticmethod
def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True):
config = deepcopy(config_global)
Expand Down Expand Up @@ -166,12 +166,10 @@ def test_decoder_forward_backward(self):
z, z_len, _ = decoder(mel_spec.transpose(1, 2), mel_lengths)
mel_spec_, mel_lengths_, _ = decoder(z, z_len, reverse=True)
mask = sequence_mask(z_len).unsqueeze(1)
print(mel_spec.shape, mask.shape)
mel_spec = mel_spec[:, : z.shape[2], :].transpose(1, 2) * mask
z = z * mask
print(mel_spec[0], mel_spec_[0])
self.assertTrue(
torch.isclose(mel_spec, mel_spec_, atol=1e-3).all(),
torch.isclose(mel_spec, mel_spec_, atol=1e-2).all(),
f"num_flow_blocks_dec={num_flow_blocks_dec}, hidden_channels_dec={hidden_channels_dec}",
)

Expand All @@ -197,6 +195,14 @@ def _get_neural_hmm(deterministic_transition=None):
).to(device)
return neural_hmm

@staticmethod
def _get_emission_model():
return EmissionModel().to(device)

@staticmethod
def _get_transition_model():
return TransitionModel().to(device)

@staticmethod
def _get_embedded_input():
input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
Expand Down Expand Up @@ -247,5 +253,147 @@ def test_process_ar_timestep(self):
c_post_prenet,
)

assert h_post_prenet.shape == (input_dummy.shape[0], config_global.memory_rnn_dim)
assert c_post_prenet.shape == (input_dummy.shape[0], config_global.memory_rnn_dim)
self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))
self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))

def test_add_go_token(self):
model = self._get_neural_hmm()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

out = model._add_go_token(mel_spec) # pylint: disable=protected-access
self.assertEqual(out.shape, mel_spec.shape)
self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not appended properly")

def test_forward_algorithm_variables(self):
model = self._get_neural_hmm()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

(
log_c,
log_alpha_scaled,
transition_matrix,
_,
) = model._initialize_forward_algorithm_variables( # pylint: disable=protected-access
mel_spec, input_dummy.shape[1] * config_global.state_per_phone
)

self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1]))
self.assertEqual(
log_alpha_scaled.shape,
(
mel_spec.shape[0],
mel_spec.shape[1],
input_dummy.shape[1] * config_global.state_per_phone,
),
)
self.assertEqual(
transition_matrix.shape,
(mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone),
)

def test_get_absorption_state_scaling_factor(self):
model = self._get_neural_hmm()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
input_lengths = input_lengths * config_global.state_per_phone
(
log_c,
log_alpha_scaled,
transition_matrix,
_,
) = model._initialize_forward_algorithm_variables( # pylint: disable=protected-access
mel_spec, input_dummy.shape[1] * config_global.state_per_phone
)
log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3)
transition_matrix = torch.randn_like(transition_matrix).sigmoid().log()
sum_final_log_c = model.get_absorption_state_scaling_factor(
mel_lengths, log_alpha_scaled, input_lengths, transition_matrix
)

text_mask = ~sequence_mask(input_lengths)
transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device)

outputs = []

for i in range(input_dummy.shape[0]):
last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf"))
log_last_transition_probability = OverflowUtils.log_clamped(
torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1])
).masked_fill(transition_prob_mask[i], -float("inf"))
outputs.append(last_log_alpha_scaled + log_last_transition_probability)

sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1)

self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all())

def test_inference(self):
model = self._get_neural_hmm()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
for temp in [0.334, 0.667, 1.0]:
outputs = model.inference(
input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold
)
self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1])
self.assertEqual(
outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]
)
self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0])

def test_emission_model(self):
model = self._get_emission_model()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device)
means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device)
std = torch.rand_like(means).to(device).clamp_(1e-3) # std should be positive
out = model(x_t, means, std, input_lengths)
self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1]))

# testing sampling
for temp in [0, 0.334, 0.667]:
out = model.sample(means, std, 0)
self.assertEqual(out.shape, means.shape)
if temp == 0:
self.assertTrue(torch.isclose(out, means).all())

def test_transition_model(self):
model = self._get_transition_model()
input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device)
transition_vector = torch.randn(input_lengths.max()).to(device)
out = model(prev_t_log_scaled_alph, transition_vector, input_lengths)
self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max()))


class TestOverflowOutputNet(unittest.TestCase):
@staticmethod
def _get_outputnet():
config = deepcopy(config_global)
outputnet = Outputnet(
config.encoder_in_out_features,
config.memory_rnn_dim,
config.out_channels,
config.outputnet_size,
config.flat_start_params,
config.std_floor,
).to(device)
return outputnet

@staticmethod
def _get_embedded_input():
input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
input_dummy
)
one_timestep_frame = torch.randn(input_dummy.shape[0], config_global.memory_rnn_dim).to(device)
return input_dummy, one_timestep_frame

def test_outputnet_forward_with_flat_start(self):
model = self._get_outputnet()
input_dummy, one_timestep_frame = self._get_embedded_input()
mean, std, transition_vector = model(one_timestep_frame, input_dummy)
self.assertTrue(torch.isclose(mean, torch.tensor(model.flat_start_params["mean"] * 1.0)).all())
self.assertTrue(torch.isclose(std, torch.tensor(model.flat_start_params["std"] * 1.0)).all())
self.assertTrue(
torch.isclose(
transition_vector.sigmoid(), torch.tensor(model.flat_start_params["transition_p"] * 1.0)
).all()
)