Skip to content

Commit

Permalink
added some bwe-related stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed Apr 24, 2024
1 parent 5667867 commit 0dc559f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
34 changes: 34 additions & 0 deletions dnn/torch/osce/losses/td_lowpass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
import scipy.signal


from utils.layers.fir import FIR

class TDLowpass(torch.nn.Module):
def __init__(self, numtaps, cutoff, power=2):
super().__init__()

self.b = scipy.signal.firwin(numtaps, cutoff)
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
self.power = power

def forward(self, y_true, y_pred):

assert len(y_true.shape) == 3 and len(y_pred.shape) == 3

diff = y_true - y_pred
diff_lp = torch.nn.functional.conv1d(diff, self.weight)

loss = torch.mean(torch.abs(diff_lp ** self.power))

return loss, diff_lp

def get_freqz(self):
freq, response = scipy.signal.freqz(self.b)

return freq, response





28 changes: 28 additions & 0 deletions dnn/torch/osce/silk_16_to_48.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import argparse

from scipy.io import wavfile
import torch
import numpy as np

from utils.layers.silk_upsampler import SilkUpsampler

parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input wave file")
parser.add_argument("output", type=str, help="output wave file")

if __name__ == "__main__":
args = parser.parse_args()

fs, x = wavfile.read(args.input)

# being lazy for now
assert fs == 16000 and x.dtype == np.int16

x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)

upsampler = SilkUpsampler()
y = upsampler(x)

y = y.squeeze().numpy().astype(np.int16)

wavfile.write(args.output, 48000, y[13:])
27 changes: 27 additions & 0 deletions dnn/torch/osce/utils/layers/fir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F


class FIR(nn.Module):
def __init__(self, numtaps, bands, desired, fs=2):
super().__init__()

if numtaps % 2 == 0:
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
numtaps += 1

a = scipy.signal.firls(numtaps, bands, desired, fs=fs)

self.weight = torch.from_numpy(a.astype(np.float32))

def forward(self, x):
num_channels = x.size(1)

weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)

y = F.conv1d(x, weight, groups=num_channels)

return y

0 comments on commit 0dc559f

Please sign in to comment.