diff --git a/TTS/tts/layers/neural_hmm/__init__.py b/TTS/tts/layers/neural_hmm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/TTS/tts/layers/neural_hmm/common_layers.py b/TTS/tts/layers/neural_hmm/common_layers.py index 9801bc57dd..0846e19bb5 100644 --- a/TTS/tts/layers/neural_hmm/common_layers.py +++ b/TTS/tts/layers/neural_hmm/common_layers.py @@ -1,11 +1,17 @@ +from typing import List + +import torch import torch.nn as nn +import torch.nn.functional as F +from TTS.tts.layers.tacotron.common_layers import Linear from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock +from TTS.tts.utils.helpers import inverse_sigmod, inverse_softplus class Encoder(nn.Module): r"""Neural HMM Encoder - + Same as Tacotron 2 encoder but increases the states per phone Args: @@ -21,8 +27,7 @@ def __init__(self, state_per_phone, in_out_channels=512): self.state_per_phone = state_per_phone self.in_out_channels = in_out_channels - - + self.convolutions = nn.ModuleList() for _ in range(3): self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) @@ -32,10 +37,9 @@ def __init__(self, state_per_phone, in_out_channels=512): num_layers=1, batch_first=True, bias=True, - bidirectional=True + bidirectional=True, ) self.rnn_state = None - def forward(self, x, input_lengths): b, _, T = x.shape @@ -52,3 +56,146 @@ def forward(self, x, input_lengths): return o, T +class ParameterModel(nn.Module): + r"""Main neural network of the outputnet + + Note: Do not put dropout layers here, the model will not converge. + + Args: + parameternetwork (List[int]): the architecture of the parameter model + input_size (int): size of input for the first layer + output_size (int): size of output i.e size of the feature dim + frame_channels (int): feature dim to set the flat start bias + init_transition_probability (float): flat start transition probability + init_mean (float): flat start mean + init_std (float): flat start std + """ + + def __init__( + self, + parameternetwork: List[int], + input_size: int, + output_size: int, + flat_start_params: dict, + frame_channels: int, + ): + super().__init__() + self.flat_start_params = flat_start_params + + self.layers = nn.ModuleList( + [Linear(inp, out) for inp, out in zip([input_size] + parameternetwork[:-1], parameternetwork)] + ) + last_layer = self._flat_start_output_layer(parameternetwork[-1], output_size, frame_channels) + self.layers.append(last_layer) + + def _flat_start_output_layer(self, input_size, output_size, frame_channels): + last_layer = nn.Linear(input_size, output_size) + last_layer.weight.data.zero_() + last_layer.bias.data[0:frame_channels] = self.flat_start_params["mean"] + last_layer.bias.data[frame_channels : 2 * frame_channels] = inverse_softplus(self.flat_start_params["std"]) + last_layer.bias.data[2 * frame_channels :] = inverse_sigmod(self.flat_start_params["transition_p"]) + return last_layer + + def forward(self, x): + for layer in self.layers[:-1]: + x = F.relu(layer(x)) + x = self.layers[-1](x) + return x + + +class Outputnet(nn.Module): + r""" + This network takes current state and previous observed values as input + and returns its parameters, mean, standard deviation and probability + of transition to the next state + """ + + def __init__( + self, + encoder_dim: int, + memory_rnn_dim: int, + frame_channels: int, + parameternetwork: List[int], + flat_start_params: dict, + std_floor: float = 1e-2, + ): + super().__init__() + + self.frame_channels = frame_channels + self.flat_start_params = flat_start_params + 
self.std_floor = std_floor
+        self.parameternetwork = parameternetwork
+
+        input_size = memory_rnn_dim + encoder_dim
+        output_size = 2 * frame_channels + 1
+
+        self._validate_parameters()
+
+        self.parametermodel = ParameterModel(
+            parameternetwork=parameternetwork,
+            input_size=input_size,
+            output_size=output_size,
+            flat_start_params=flat_start_params,
+            frame_channels=frame_channels,
+        )
+
+    def _validate_parameters(self):
+        """Validate the hyperparameters.
+
+        Raises:
+            AssertionError: when the parameter network is not defined
+            AssertionError: when the transition probability is not between 0 and 1
+        """
+        assert (
+            len(self.parameternetwork) >= 1
+        ), f"Parameter network must have at least one layer. Check the `parameternetwork` entry in the config file. Provided: {self.parameternetwork}"
+        assert (
+            0 < self.flat_start_params["transition_p"] < 1
+        ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
+
+    def forward(self, ar_mels, inputs):
+        r"""Takes the autoregressive frame representation and the encoder states and returns
+        the means, stds and transition probabilities for the current timestep.
+
+        Args:
+            ar_mels (torch.FloatTensor): autoregressive frame representation (memory RNN output)
+                - shape: (B, memory_rnn_dim)
+            inputs (torch.FloatTensor): encoder states
+                - shape: (B, hidden_states, hidden_state_dim)
+
+        Returns:
+            means: means of the emission distribution for each feature
+                - shape: (B, hidden_states, feature_size)
+            stds: standard deviations of the emission distribution for each feature
+                - shape: (B, hidden_states, feature_size)
+            transition_vectors: transition vectors for the current hidden states
+                - shape: (B, hidden_states)
+        """
+        batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1]
+        N = inputs.shape[1]
+
+        ar_mels = ar_mels.unsqueeze(1).expand(batch_size, N, prenet_dim)
+        ar_mels = torch.cat((ar_mels, inputs), dim=2)
+        ar_mels = self.parametermodel(ar_mels)
+
+        mean, std, transition_vector = (
+            ar_mels[:, :, 0 : self.frame_channels],
+            ar_mels[:, :, self.frame_channels : 2 * self.frame_channels],
+            ar_mels[:, :, 2 * self.frame_channels :].squeeze(2),
+        )
+        std = F.softplus(std)
+        std = self._floor_std(std)
+        return mean, std, transition_vector
+
+    def _floor_std(self, std):
+        r"""Clamps the standard deviation so that it does not fall below ``std_floor``.
+
+        This prevents the model from cheating for higher likelihoods by collapsing one of the
+        Gaussians into a point mass.
+
+        Args:
+            std (torch.FloatTensor): standard deviations to be floored
+        """
+        original_tensor = std.clone().detach()
+        std = torch.clamp(std, min=self.std_floor)
+        if torch.any(original_tensor != std):
+            print(
+                "[*] Standard deviation was floored!
The model is preventing overfitting, nothing serious to worry about" + ) + return std diff --git a/TTS/tts/layers/neural_hmm/hmm.py b/TTS/tts/layers/neural_hmm/hmm.py index 594d8e9791..6a9baae152 100644 --- a/TTS/tts/layers/neural_hmm/hmm.py +++ b/TTS/tts/layers/neural_hmm/hmm.py @@ -1,67 +1,94 @@ from dataclasses import dataclass +from typing import List import torch import torch.distributions as tdist import torch.nn as nn -import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from TTS.tts.layers.neural_hmm.common_layers import Outputnet from TTS.tts.layers.tacotron.common_layers import Prenet from TTS.tts.utils.helpers import log_clamped, logsumexp, sequence_mask class HMM(nn.Module): - def __init__(self, - frame_channels, - ar_order, - prenet_type, - prenet_dim, - prenet_dropout, - memory_rnn_dim, - prenet_dropout_while_eval, - ar_data_dropout=False, + """Autoregressive left to right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)" + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using + HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase + training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right + no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an + HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without + approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting + example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst + achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate. + """ + + def __init__( + self, + frame_channels: int, + ar_order: int, + encoder_dim: int, + prenet_type: str, + prenet_dim: int, + prenet_dropout: float, + memory_rnn_dim: int, + prenet_dropout_at_inference: bool, + parameternetwork: List[int], + flat_start_params: dict, + std_floor: float, ): super().__init__() - + + self.frame_channels = frame_channels + self.ar_order = ar_order + self.memory_rnn_dim = memory_rnn_dim + self.transition_model = TransitionModel() self.emission_model = EmissionModel() - + assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}" - + self.ar_order = ar_order self.prenet = Prenet( in_features=frame_channels * ar_order, prenet_type=prenet_type, prenet_dropout=prenet_dropout, + dropout_at_inference=prenet_dropout_at_inference, out_features=[self.prenet_dim, self.prenet_dim], - bias=False + bias=False, ) self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim) - self.output_net = OutputNet(hparams) + self.output_net = Outputnet(encoder_dim, memory_rnn_dim, frame_channels, parameternetwork, flat_start_params, std_floor) self.register_buffer("go_tokens", torch.zeros(ar_order, 1)) def forward(self, inputs, inputs_len, mels, mel_lens): - r""" - HMM forward algorithm for training + r"""HMM forward algorithm for training uses logarithmic version of Rabiner (1989) forward algorithm. 
         Args:
-            text_embs (torch.FloatTensor): Encoder outputs
-            text_lens (torch.LongTensor): Encoder output lengths
+            inputs (torch.FloatTensor): Encoder outputs
+            inputs_len (torch.LongTensor): Encoder output lengths
             mels (torch.FloatTensor): Mel inputs for teacher forcing
             mel_lens (torch.LongTensor): Length of mel inputs
-
+
         Shapes:
-            - text_embs: (batch, C, T)
-
-
+            - inputs: (B, T, D_out_enc)
+            - inputs_len: (B)
+            - mels: (B, T_mel, D_mel)
+            - mel_lens: (B)

         Returns:
             log_prob (torch.FloatTensor): Log probability of the sequence
         """
-        # Get dimensions of inputs
-        batch_size, self.N = mels.shape[0]
+        batch_size, self.N, _ = inputs.shape
         T_max = torch.max(mel_lens)
-        self.N = inputs.shape[1]
-        mels = mels.permute(0, 2, 1) #! TODO: check dataloader here
+        mels = mels.permute(0, 2, 1)  #! TODO: check dataloader here

         # Initialize forward algorithm
         log_state_priors = self._initialize_log_state_priors(inputs)
@@ -69,49 +96,37 @@ def forward(self, inputs, inputs_len, mels, mel_lens):
         # Initialize autoregression elements
         ar_inputs = self._add_go_token(mels)
-        h_post_prenet, c_post_prenet = self.init_lstm_states(batch_size, self.hparams.post_prenet_rnn_dim, mels)
+        h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels)

         for t in range(T_max):
             # Process Autoregression
-            h_post_prenet, c_post_prenet = self.process_ar_timestep(
-                t,
-                ar_inputs,
-                h_post_prenet,
-                c_post_prenet,
-                data_dropout_flag,
-                prenet_dropout_flag,
-            )
-
+            h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory)
             # Get mean, std and transition vector from decoder for this timestep
-            mean, std, transition_vector = self.decoder(h_post_prenet, inputs)
-
-            # Forward algorithm for this timestep
+            # Note: Gradient checkpointing currently doesn't work with multiple GPUs inside a loop
+            mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
             if t == 0:
                 log_alpha_temp = log_state_priors + self.emission_model(mels[:, 0], mean, std, inputs_len)
             else:
                 log_alpha_temp = self.emission_model(mels[:, t], mean, std, inputs_len) + self.transition_model(
                     self.log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len
                 )
-
             log_c[:, t] = torch.logsumexp(log_alpha_temp, dim=1)
             self.log_alpha_scaled[:, t, :] = log_alpha_temp - log_c[:, t].unsqueeze(1)
+            self.transition_vector[:, t] = transition_vector  # needed for absorption state calculation

             # Save for plotting
-            self.transition_vector[:, t] = transition_vector.detach()
             self.means.append(mean.detach())

-        log_c = self.mask_lengths(mels, mel_lens, log_c)
+        log_c = self._mask_lengths(mel_lens, log_c)

-        sum_final_log_c = self.get_absorption_state_scaling_factor(
-            mel_lens, self.log_alpha_scaled, inputs_len
-        )
+        sum_final_log_c = self.get_absorption_state_scaling_factor(mel_lens, self.log_alpha_scaled, inputs_len)

         log_probs = torch.sum(log_c, dim=1) + sum_final_log_c
         return log_probs

-    def mask_lengths(self, mel_inputs, mel_inputs_lengths, log_c):
+    def _mask_lengths(self, mel_inputs_lengths, log_c):
         """
         Mask the lengths of the forward variables so that the variable lengths
         do not contribute to the loss calculation
@@ -122,29 +137,18 @@ def mask_lengths(self, mel_inputs, mel_inputs_lengths, log_c):
         Returns:
             log_c (torch.FloatTensor) : scaled probabilities (batch, T)
         """
-        batch_size, T, n_mel_channel = mel_inputs.shape
-        # create len mask
-        mask_tensor = mel_inputs.new_zeros(T)
-        mask_log_c = torch.arange(T, out=mask_tensor).expand(len(mel_inputs_lengths), T) < (
-            mel_inputs_lengths
-        ).unsqueeze(1)
-        # mask log_c
+        mask_log_c =
sequence_mask(mel_inputs_lengths) log_c = log_c * mask_log_c - - # mask log_alpha_scaled mask_log_alpha_scaled = mask_log_c.unsqueeze(2) self.log_alpha_scaled = self.log_alpha_scaled * mask_log_alpha_scaled - return log_c - def process_ar_timestep( + def _process_ar_timestep( self, t, ar_inputs, - h_post_prenet, - c_post_prenet, - data_dropout_flag, - prenet_dropout_flag, + h_memory, + c_memory, ): """ Process autoregression in timestep @@ -165,12 +169,10 @@ def process_ar_timestep( h_post_prenet (torch.FloatTensor): rnn hidden state of the current timestep c_post_prenet (torch.FloatTensor): rnn cell state of the current timestep """ - prenet_input = self.perform_data_dropout_of_ar_mel_inputs( - ar_inputs[:, t : t + self.hparams.ar_order], data_dropout_flag - ) - ar_inputs_prenet = self.prenet(prenet_input.flatten(1), prenet_dropout_flag) - h_post_prenet, c_post_prenet = self.post_prenet_rnn(ar_inputs_prenet, (h_post_prenet, c_post_prenet)) - return h_post_prenet, c_post_prenet + prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1) + memory_inputs = self.prenet(prenet_input) + h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory)) + return h_memory, c_memory def _add_go_token(self, mel_inputs): """Append the go token to create the autoregressive input @@ -180,9 +182,7 @@ def _add_go_token(self, mel_inputs): ar_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) """ batch_size, T, _ = mel_inputs.shape - go_tokens = self.go_tokens.unsqueeze(0).expand( - batch_size, self.ar_order, self.frame_channels - ) + go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels) ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, :T] return ar_inputs @@ -204,7 +204,7 @@ def _initialize_forward_algorithm_variables(self, mel_inputs): self.transition_vector = mel_inputs.new_zeros((batch_size, T_max, self.N)) return log_c - def init_lstm_states(self, batch_size, hidden_state_dim, device_tensor): + def _init_lstm_states(self, batch_size, hidden_state_dim, device_tensor): r""" Initialize Hidden and Cell states for LSTM Cell @@ -224,46 +224,7 @@ def init_lstm_states(self, batch_size, hidden_state_dim, device_tensor): device_tensor.new_zeros(batch_size, hidden_state_dim), ) - def perform_data_dropout_of_ar_mel_inputs(self, mel_inputs, dropout_flag): - r""" - Takes mel frames as inputs and applies data dropout on it - - Args: - mel_inputs (torch.FloatTensor): mel frames - shape: (batch, ar_order, frame_channels) - dropout_flag (Bool): to dropout during eval or not - - Returns: - mel_inputs (torch.FloatTensor) : data droped out mel frames - """ - batch_size, n_frame_per_step, _ = mel_inputs.shape - - data_dropout_mask = F.dropout( - mel_inputs.new_ones(batch_size, n_frame_per_step), - p=self.hparams.data_dropout, - training=dropout_flag, - ).unsqueeze(2) - mel_inputs = mel_inputs * data_dropout_mask - return mel_inputs - - def get_dropout_while_eval(self, dropout_while_eval: bool) -> bool: - r""" - Returns the flag to be true or false based on the value given during evaluation - - Args: - dropout_while_eval (bool): flag to dropout while eval or not - - Returns: - (bool): dropout flag - """ - if dropout_while_eval: - dropout_flag = True - else: - dropout_flag = self.training - - return dropout_flag - - def get_absorption_state_scaling_factor(self, mel_inputs_lengths, log_alpha_scaled, text_lengths): + def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len): r""" Returns the final scaling factor of absorption state 
Args: @@ -284,14 +245,11 @@ def get_absorption_state_scaling_factor(self, mel_inputs_lengths, log_alpha_scal Returns: """ - max_text_len = log_alpha_scaled.shape[2] - mask_tensor = log_alpha_scaled.new_zeros(max_text_len) - state_lengths_mask = torch.arange(max_text_len, out=mask_tensor).expand(len(text_lengths), max_text_len) < ( - text_lengths - ).unsqueeze(1) + max_inputs_len = log_alpha_scaled.shape[2] + state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len) last_log_alpha_scaled_index = ( - (mel_inputs_lengths - 1).unsqueeze(-1).expand(-1, self.N).unsqueeze(1) + (mels_len - 1).unsqueeze(-1).expand(-1, self.N).unsqueeze(1) ) # Batch X Hidden State Size last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1) last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf")) @@ -301,28 +259,29 @@ def get_absorption_state_scaling_factor(self, mel_inputs_lengths, log_alpha_scal log_probability_of_transitioning = log_clamped(last_transition_probability) last_transition_probability_index = ( - torch.arange(max_text_len, out=mask_tensor).expand(len(text_lengths), max_text_len) - ) == (text_lengths - 1).unsqueeze(1) + torch.arange(max_inputs_len, dtype=inputs_len.dtype, device=inputs_len.device).expand( + len(inputs_len), max_inputs_len + ) + ) == (inputs_len - 1).unsqueeze(1) log_probability_of_transitioning = log_probability_of_transitioning.masked_fill( ~last_transition_probability_index, -float("inf") ) - final_log_c = last_log_alpha_scaled + log_probability_of_transitioning - # Uncomment the line below if you get nan values because of low precision - # in half precision training + # Uncomment the line below if you get nan values because of low precisin in half precision training # final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min) sum_final_log_c = torch.logsumexp(final_log_c, dim=1) return sum_final_log_c @torch.inference_mode() - def sample(self, encoder_outputs, sampling_temp=1.0, T=None): + def sample(self, inputs, sampling_temp=1.0, T=None): r""" Samples an output from the parameter models Args: encoder_outputs (float tensor): (batch, text_len, encoder_embedding_dim) + sampling_temp T (int): Max time to sample Returns: @@ -330,16 +289,10 @@ def sample(self, encoder_outputs, sampling_temp=1.0, T=None): z (list[int]): Hidden states travelled """ if not T: - T = self.hparams.max_sampling_time - - self.N = encoder_outputs.shape[1] - 1 - if self.hparams.ar_order > 0: - ar_mel_inputs = self.go_tokens.unsqueeze(0) - else: - raise ValueError( - "ar_order should be greater than 0, \ - it is an Autoregressive model" - ) + T = self.max_sampling_time + + self.N = inputs.shape[1] + prenet_input = self.go_tokens.unsqueeze(0) z, x = [], [] t = 0 @@ -348,50 +301,38 @@ def sample(self, encoder_outputs, sampling_temp=1.0, T=None): current_z_number = 0 z.append(current_z_number) - h_post_prenet, c_post_prenet = self.init_lstm_states(1, self.hparams.post_prenet_rnn_dim, ar_mel_inputs) + h_memory, c_memory = self._init_lstm_states(1, self.post_prenet_rnn_dim, prenet_input) input_parameter_values = [] output_parameter_values = [] quantile = 1 - prenet_dropout_flag = self.get_dropout_while_eval(self.hparams.prenet_dropout_while_eval) - while True: - if self.hparams.data_dropout_while_sampling: - dropout_mask = F.dropout( - ar_mel_inputs.new_ones(ar_mel_inputs.shape[0], ar_mel_inputs.shape[1], 1), - p=self.hparams.data_dropout, - ) - ar_mel_inputs = dropout_mask * ar_mel_inputs - - prenet_output = 
self.prenet(ar_mel_inputs.flatten(1).unsqueeze(0), prenet_dropout_flag) + memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0)) # will be 1 while sampling - h_post_prenet, c_post_prenet = self.post_prenet_rnn( - prenet_output.squeeze(0), (h_post_prenet, c_post_prenet) - ) + h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory)) - z_t = encoder_outputs[:, current_z_number] - mean, std, transition_vector = self.decoder(h_post_prenet, z_t.unsqueeze(0)) + z_t = inputs[:, current_z_number] + mean, std, transition_vector = self.decoder(h_memory, z_t.unsqueeze(0)) transition_probability = torch.sigmoid(transition_vector.flatten()) staying_probability = torch.sigmoid(-transition_vector.flatten()) input_parameter_values.append([ar_mel_inputs, current_z_number]) output_parameter_values.append([mean, std, transition_probability]) - x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp) - - if self.hparams.predict_means: + if self.predict_means: x_t = mean + else: + x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp) ar_mel_inputs = torch.cat((ar_mel_inputs, x_t), dim=1)[:, 1:] x.append(x_t.flatten()) transition_matrix = torch.cat((staying_probability, transition_probability)) quantile *= staying_probability - - if not self.hparams.deterministic_transition: + if not self.deterministic_transition: switch = transition_matrix.multinomial(1)[0].item() else: - switch = quantile < self.hparams.duration_quantile_threshold + switch = quantile < self.duration_quantile_threshold if switch: current_z_number += 1 @@ -421,20 +362,20 @@ def _initialize_log_state_priors(self, text_embeddings): @dataclass class TransitionModel(nn.Module): """Transition Model of the HMM, it represents the probability of transitioning - form current state to all other states""" + form current state to all other states""" - staying_probability: torch.FloatTensor = None - transition_probability: torch.FloatTensor = None + staying_p: torch.FloatTensor = None + transition_pr: torch.FloatTensor = None - def update_current_values(self, staying: torch.FloatTensor, transitioning: torch.FloatTensor): + def _update_current_values(self, staying: torch.FloatTensor, transitioning: torch.FloatTensor): """ Make reference of the staying and transitioning probabilities as instance parameters of class """ - self.staying_probability = staying - self.transition_probability = transitioning + self.staying_p = staying + self.transition_pr = transitioning - def forward(self, log_alpha_scaled, transition_vector, state_lengths): + def forward(self, log_alpha_scaled, transition_vector, inputs_len): r""" product of the past state with transitional probabilities in log space @@ -444,43 +385,34 @@ def forward(self, log_alpha_scaled, transition_vector, state_lengths): - shape: (batch size, N) transition_vector (torch.tensor): transition vector for each state - shape: (N) - state_lengths (int tensor): Lengths of states in a batch + inputs_len (int tensor): Lengths of states in a batch - shape: (batch) Returns: out (torch.FloatTensor): log probability of transitioning to each state """ - T_max = log_alpha_scaled.shape[1] - - transition_probability = torch.sigmoid(transition_vector) - staying_probability = torch.sigmoid(-transition_vector) + transition_p = torch.sigmoid(transition_vector) + staying_p = torch.sigmoid(-transition_vector) - self.update_current_values(staying_probability, transition_probability) + self._update_current_values(staying_p, transition_p) - log_staying_probability = 
log_clamped(staying_probability) - log_transition_probability = log_clamped(transition_probability) + log_staying_probability = log_clamped(staying_p) + log_transition_probability = log_clamped(transition_p) staying = log_alpha_scaled + log_staying_probability leaving = log_alpha_scaled + log_transition_probability leaving = leaving.roll(1, dims=1) leaving[:, 0] = -float("inf") - - mask_tensor = log_alpha_scaled.new_zeros(T_max) - not_state_lengths_mask = ~( - torch.arange(T_max, out=mask_tensor).expand(len(state_lengths), T_max) < (state_lengths).unsqueeze(1) - ) - + inputs_len_mask = sequence_mask(inputs_len) out = logsumexp(torch.stack((staying, leaving), dim=2), dim=2) - - out = out.masked_fill(not_state_lengths_mask, -float("inf")) - + out = out.masked_fill(~inputs_len_mask, -float("inf")) # There are no states to contribute to the loss return out @dataclass class EmissionModel(nn.Module): """Emission Model of the HMM, it represents the probability of - emitting an observation based on the current state""" + emitting an observation based on the current state""" distribution_function: tdist.Distribution = tdist.normal.Normal @@ -510,4 +442,4 @@ def forward(self, x_t, means, stds, state_lengths): out = emission_dists.log_prob(x_t.unsqueeze(1)) state_lengths_mask = sequence_mask(state_lengths) out = torch.sum(out * state_lengths_mask, dim=2) - return out \ No newline at end of file + return out