Skip to content

Commit

Permalink
moving loss to the model file
Browse files Browse the repository at this point in the history
  • Loading branch information
shivammehta25 committed Dec 23, 2022
1 parent 3149b43 commit 8aff87a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 28 deletions.
22 changes: 0 additions & 22 deletions TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,25 +872,3 @@ def forward(

return_dict["loss"] = loss
return return_dict


class NLLLoss(nn.Module):
"""Negative log likelihood loss."""

def __init__(self):
super().__init__()

def forward(self, log_prob: torch.Tensor) -> dict:
"""Compute the loss.
Args:
logits (Tensor): [B, T, D]
Returns:
Tensor: [1]
"""
return_dict = {}
return_dict["loss"] = - log_prob.mean()
return return_dict

28 changes: 22 additions & 6 deletions TTS/tts/models/overflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from coqpit import Coqpit
from trainer.logging.tensorboard_logger import TensorboardLogger

from TTS.tts.layers.neural_hmm.common_layers import Encoder, OverFlowUtils
from TTS.tts.layers.neural_hmm.decoder import Decoder
from TTS.tts.layers.neural_hmm.hmm import NeuralHMM
from TTS.tts.layers.neural_hmm.plotting_utils import (
from TTS.tts.layers.overflow.common_layers import Encoder, OverFlowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import NeuralHMM
from TTS.tts.layers.overflow.plotting_utils import (
get_spec_from_most_probable_state,
plot_transition_probabilities_to_numpy,
)
Expand Down Expand Up @@ -225,8 +225,6 @@ def inference(

@staticmethod
def get_criterion():
from TTS.tts.layers.losses import NLLLoss # pylint: disable=import-outside-toplevel

return NLLLoss()

@staticmethod
Expand Down Expand Up @@ -342,3 +340,21 @@ def eval_log(

figures = self._create_logs(batch, outputs)
logger.eval_figures(steps, figures)


class NLLLoss(nn.Module):
"""Negative log likelihood loss."""

def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use
"""Compute the loss.
Args:
logits (Tensor): [B, T, D]
Returns:
Tensor: [1]
"""
return_dict = {}
return_dict["loss"] = -log_prob.mean()
return return_dict

0 comments on commit 8aff87a

Please sign in to comment.