-
Notifications
You must be signed in to change notification settings - Fork 619
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jan Buethe
committed
Apr 24, 2024
1 parent
5667867
commit 0dc559f
Showing
3 changed files
with
89 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import scipy.signal | ||
|
||
|
||
from utils.layers.fir import FIR | ||
|
||
class TDLowpass(torch.nn.Module):
    """Time-domain loss on the low-pass filtered difference of two signals.

    A windowed FIR low-pass filter is designed once with
    ``scipy.signal.firwin`` and applied (without padding) to
    ``y_true - y_pred``. The loss is the mean of the filtered difference's
    magnitude raised to ``power``.

    Args:
        numtaps: number of filter taps, forwarded to ``scipy.signal.firwin``.
        cutoff: cutoff frequency, forwarded to ``scipy.signal.firwin``
            (no ``fs`` is passed, so scipy's default normalisation applies).
        power: exponent applied to the absolute filtered difference.
    """

    def __init__(self, numtaps, cutoff, power=2):
        super().__init__()

        # Keep the float64 coefficients around for get_freqz().
        self.b = scipy.signal.firwin(numtaps, cutoff)

        # Register the kernel as a buffer so that ``module.to(device)`` /
        # ``.cuda()`` move it together with the module. It is non-persistent
        # so the state_dict stays exactly as it was when ``weight`` was a
        # plain attribute (i.e. empty), and old checkpoints keep loading.
        self.register_buffer(
            "weight",
            torch.from_numpy(self.b).float().view(1, 1, -1),
            persistent=False,
        )
        self.power = power

    def forward(self, y_true, y_pred):
        """Return ``(loss, diff_lp)`` for two 3-dimensional tensors.

        The kernel has shape ``(1, 1, numtaps)``, so ``conv1d`` requires the
        channel dimension (dim 1) of the inputs to be 1. ``diff_lp`` is the
        filtered difference and is ``numtaps - 1`` samples shorter than the
        input along the last dimension because no padding is applied.
        """
        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3, \
            "y_true and y_pred must both be 3-dimensional"

        diff = y_true - y_pred

        # Match the kernel to the input so the loss also works when the
        # module itself was never moved, or when the input is float64/half.
        weight = self.weight.to(device=diff.device)
        if diff.is_floating_point():
            weight = weight.to(dtype=diff.dtype)

        diff_lp = torch.nn.functional.conv1d(diff, weight)

        # abs() is taken BEFORE the power: identical for integer exponents,
        # but a fractional exponent on a negative sample would yield NaN.
        loss = torch.mean(torch.abs(diff_lp) ** self.power)

        return loss, diff_lp

    def get_freqz(self):
        """Return the ``scipy.signal.freqz`` result for the filter taps."""
        freq, response = scipy.signal.freqz(self.b)

        return freq, response
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import argparse | ||
|
||
from scipy.io import wavfile | ||
import torch | ||
import numpy as np | ||
|
||
from utils.layers.silk_upsampler import SilkUpsampler | ||
|
||
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input wave file")
parser.add_argument("output", type=str, help="output wave file")


def to_int16_pcm(samples):
    """Convert float samples to int16 with rounding and saturation.

    A bare ``astype(np.int16)`` truncates toward zero and has undefined /
    wrap-around behaviour for values outside the int16 range, which an
    interpolation filter can produce on loud input.
    """
    limits = np.iinfo(np.int16)
    return np.clip(np.rint(samples), limits.min, limits.max).astype(np.int16)


if __name__ == "__main__":
    args = parser.parse_args()

    fs, x = wavfile.read(args.input)

    # being lazy for now: only mono 16 kHz int16 input is handled. Explicit
    # checks are used instead of assert so they survive ``python -O``; the
    # mono check prevents a multi-channel file from being silently flattened
    # into one interleaved signal by the view() below.
    if fs != 16000 or x.dtype != np.int16 or x.ndim != 1:
        parser.error(
            f"expected a mono 16 kHz int16 wave file, got fs={fs}, "
            f"dtype={x.dtype}, shape={x.shape}"
        )

    # Shape the signal as (1, 1, num_samples) for the upsampler.
    x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)

    upsampler = SilkUpsampler()
    # Pure inference: no autograd graph, so .numpy() below is always legal.
    with torch.no_grad():
        y = upsampler(x)

    y = to_int16_pcm(y.squeeze().numpy())

    # NOTE(review): the first 13 output samples are dropped; presumably this
    # compensates for the upsampler's delay — confirm against SilkUpsampler.
    wavfile.write(args.output, 48000, y[13:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import numpy as np | ||
import scipy.signal | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class FIR(nn.Module):
    """Fixed least-squares FIR filter applied independently to each channel.

    The taps are designed once with ``scipy.signal.firls`` and are not
    trainable.

    Args:
        numtaps: number of filter taps; an even value is bumped to the next
            odd number (with a printed warning) before the design call.
        bands: band edges, forwarded to ``scipy.signal.firls``.
        desired: desired gains at the band edges, forwarded to
            ``scipy.signal.firls``.
        fs: sampling frequency forwarded to ``scipy.signal.firls``.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # Register the taps as a buffer so that ``module.to(device)`` /
        # ``.cuda()`` move them with the module. Non-persistent keeps the
        # state_dict identical to when ``weight`` was a plain attribute, so
        # existing checkpoints still load with strict=True.
        self.register_buffer(
            "weight",
            torch.from_numpy(a.astype(np.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Filter every channel of ``x`` with the same taps.

        ``x`` is indexed as (batch, channels, time) by ``conv1d``. No padding
        is applied, so the output is ``numtaps - 1`` samples shorter along
        the last dimension.
        """
        num_channels = x.size(1)

        # Match the taps to the input so float64/half or GPU inputs work
        # even if the module itself was never moved or cast.
        taps = self.weight.to(device=x.device)
        if x.is_floating_point():
            taps = taps.to(dtype=x.dtype)

        # One identical single-input kernel per channel; groups=num_channels
        # makes conv1d filter each channel on its own.
        weight = torch.repeat_interleave(taps.view(1, 1, -1), num_channels, 0)

        y = F.conv1d(x, weight, groups=num_channels)

        return y