From b5bc058e3c2b8095906e2e19a6f13c3dbdd327fb Mon Sep 17 00:00:00 2001 From: "Christian J. Steinmetz" Date: Wed, 19 Apr 2023 21:53:30 +0000 Subject: [PATCH] adding mag_distance parameter back --- auraloss/freq.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/auraloss/freq.py b/auraloss/freq.py index a1b3b59..a5efe70 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -110,6 +110,7 @@ def __init__( eps: float = 1e-8, output: str = "loss", reduction: str = "mean", + mag_distance: str = "L1", device: Any = None, ): super().__init__() @@ -129,19 +130,23 @@ def __init__( self.eps = eps self.output = output self.reduction = reduction + self.mag_distance = mag_distance self.device = device self.spectralconv = SpectralConvergenceLoss() self.logstft = STFTMagnitudeLoss( - log=True, reduction=reduction, distance=mag_distance + log=True, + reduction=reduction, + distance=mag_distance, ) self.linstft = STFTMagnitudeLoss( - log=False, reduction=reduction, distance=mag_distance + log=False, + reduction=reduction, + distance=mag_distance, ) # setup mel filterbank if scale is not None: - try: import librosa.filters except Exception as e: @@ -362,7 +367,7 @@ def __init__( scale_invariance: bool = False, **kwargs, ): - super(MultiResolutionSTFTLoss, self).__init__() + super().__init__() assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all self.fft_sizes = fft_sizes self.hop_sizes = hop_sizes @@ -451,7 +456,7 @@ def __init__( randomize_rate=1, **kwargs, ): - super(RandomResolutionSTFTLoss, self).__init__() + super().__init__() self.resolutions = resolutions self.min_fft_size = min_fft_size self.max_fft_size = max_fft_size @@ -558,7 +563,7 @@ def __init__( output: str = "loss", **kwargs, ): - super(SumAndDifferenceSTFTLoss, self).__init__() + super().__init__() self.sd = SumAndDifference() self.w_sum = w_sum self.w_diff = w_diff @@ -571,7 +576,7 @@ def __init__( **kwargs, ) - def forward(self, input, target): + def forward(self, input: torch.Tensor, target: torch.Tensor): """This loss function assumes batched input of stereo audio in the time domain. Args: