Skip to content

Commit

Permalink
adding mag_distance parameter back
Browse files Browse the repository at this point in the history
  • Loading branch information
csteinmetz1 committed Apr 19, 2023
1 parent 9be46b6 commit b5bc058
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
eps: float = 1e-8,
output: str = "loss",
reduction: str = "mean",
mag_distance: str = "L1",
device: Any = None,
):
super().__init__()
Expand All @@ -129,19 +130,23 @@ def __init__(
self.eps = eps
self.output = output
self.reduction = reduction
self.mag_distance = mag_distance
self.device = device

self.spectralconv = SpectralConvergenceLoss()
self.logstft = STFTMagnitudeLoss(
log=True, reduction=reduction, distance=mag_distance
log=True,
reduction=reduction,
distance=mag_distance,
)
self.linstft = STFTMagnitudeLoss(
log=False, reduction=reduction, distance=mag_distance
log=False,
reduction=reduction,
distance=mag_distance,
)

# setup mel filterbank
if scale is not None:

try:
import librosa.filters
except Exception as e:
Expand Down Expand Up @@ -362,7 +367,7 @@ def __init__(
scale_invariance: bool = False,
**kwargs,
):
super(MultiResolutionSTFTLoss, self).__init__()
super().__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
self.fft_sizes = fft_sizes
self.hop_sizes = hop_sizes
Expand Down Expand Up @@ -451,7 +456,7 @@ def __init__(
randomize_rate=1,
**kwargs,
):
super(RandomResolutionSTFTLoss, self).__init__()
super().__init__()
self.resolutions = resolutions
self.min_fft_size = min_fft_size
self.max_fft_size = max_fft_size
Expand Down Expand Up @@ -558,7 +563,7 @@ def __init__(
output: str = "loss",
**kwargs,
):
super(SumAndDifferenceSTFTLoss, self).__init__()
super().__init__()
self.sd = SumAndDifference()
self.w_sum = w_sum
self.w_diff = w_diff
Expand All @@ -571,7 +576,7 @@ def __init__(
**kwargs,
)

def forward(self, input, target):
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""This loss function assumes batched input of stereo audio in the time domain.
Args:
Expand Down

0 comments on commit b5bc058

Please sign in to comment.