Make jit compilation optional for functional and use nn.Module #314

Merged · 16 commits · Nov 20, 2019
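For context, a minimal sketch of the pattern the title describes: the transform is a plain nn.Module whose forward delegates to a TorchScript-compatible functional, and JIT compilation becomes an explicit opt-in via torch.jit.script instead of something applied at import time. This is illustrative only, not code from this PR; the simplified Spectrogram class is an assumption, and the F.spectrogram argument order is taken from the test_torchscript_spectrogram test below.

import torch
from torch import nn
import torchaudio.functional as F


class Spectrogram(nn.Module):
    # Simplified sketch; the real torchaudio transform exposes more parameters.
    def __init__(self, n_fft: int = 400, hop_length: int = 200, power: float = 2.0):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.power = power
        self.register_buffer("window", torch.hann_window(n_fft))

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Same argument order as the F.spectrogram call in the tests below.
        return F.spectrogram(waveform, 0, self.window, self.n_fft,
                             self.hop_length, self.n_fft, self.power, False)


waveform = torch.rand(1, 1000)
spec = Spectrogram()(waveform)              # eager use: no compilation involved
scripted = torch.jit.script(Spectrogram())  # JIT only when explicitly requested
assert torch.allclose(scripted(waveform), spec)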
94 changes: 94 additions & 0 deletions test/test_functional.py
@@ -16,6 +16,15 @@
import librosa


def _test_torchscript_functional(py_method, *args, **kwargs):
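# Script the functional with torch.jit.script and check that the scripted output matches the eager output.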
jit_method = torch.jit.script(py_method)

jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)

assert torch.allclose(jit_out, py_out)


class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100
@@ -25,6 +34,21 @@ class TestFunctional(unittest.TestCase):
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')

def test_torchscript_spectrogram(self):

tensor = torch.rand((1, 1000))
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws)
power = 2
normalize = False

_test_torchscript_functional(
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
)

def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
@@ -49,6 +73,7 @@ def test_compute_deltas_randn(self):
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)

def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
@@ -63,6 +88,7 @@ def test_batch_pitch(self):

self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)

def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
@@ -424,6 +450,74 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):

assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)

def test_torchscript_create_fb_matrix(self):

n_stft = 100
f_min = 0.0
f_max = 20.0
n_mels = 10
sample_rate = 16000

_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)

def test_torchscript_amplitude_to_DB(self):

spec = torch.rand((6, 201))
multiplier = 10.0
amin = 1e-10
db_multiplier = 0.0
top_db = 80.0

_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)

def test_torchscript_create_dct(self):

n_mfcc = 40
n_mels = 128
norm = "ortho"

_test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm)

def test_torchscript_mu_law_encoding(self):

tensor = torch.rand((1, 10))
qc = 256

_test_torchscript_functional(F.mu_law_encoding, tensor, qc)

def test_torchscript_mu_law_decoding(self):

tensor = torch.rand((1, 10))
qc = 256

_test_torchscript_functional(F.mu_law_decoding, tensor, qc)

def test_torchscript_complex_norm(self):

complex_tensor = torch.randn(1, 2, 1025, 400, 2)
power = 2

_test_torchscript_functional(F.complex_norm, complex_tensor, power)

def test_mask_along_axis(self):

specgram = torch.randn(2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2

_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)

def test_mask_along_axis_iid(self):

specgrams = torch.randn(4, 2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2

_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)


@pytest.mark.parametrize('complex_tensor', [
torch.randn(1, 2, 1025, 400, 2),
16 changes: 16 additions & 0 deletions test/test_functional_filtering.py
@@ -9,6 +9,15 @@
import time


def _test_torchscript_functional(py_method, *args, **kwargs):
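# Same TorchScript consistency helper as in test_functional.py: scripted and eager outputs must agree.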
jit_method = torch.jit.script(py_method)

jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)

assert torch.allclose(jit_out, py_out)


class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()

@@ -79,6 +88,7 @@ def _test_lfilter(self, waveform, device):
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)

def test_lfilter(self):

@@ -116,6 +126,7 @@ def test_lowpass(self):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)

def test_highpass(self):
"""
@@ -135,6 +146,7 @@ def test_highpass(self):

# TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)

def test_equalizer(self):
"""
@@ -155,6 +167,7 @@ def test_equalizer(self):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)

assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)

def test_perf_biquad_filtering(self):

@@ -183,6 +196,9 @@ def test_perf_biquad_filtering(self):
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering

assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)


if __name__ == "__main__":
175 changes: 0 additions & 175 deletions test/test_jit.py

This file was deleted.
