Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jan 22, 2021
2 parents 9c1b322 + b70bef5 commit 5ee73c2
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions TTS/vocoder/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from torch.nn import functional as F


class TorchSTFT():
class TorchSTFT(nn.Module):
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
""" Torch based STFT operation """
super(TorchSTFT, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.window = nn.Parameter(getattr(torch, window)(win_length),
requires_grad=False)

def __call__(self, x):
# B x D x T x 2
Expand All @@ -22,7 +24,8 @@ def __call__(self, x):
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True)
onesided=True,
return_complex=False)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
Expand Down

0 comments on commit 5ee73c2

Please sign in to comment.