-
Notifications
You must be signed in to change notification settings - Fork 696
biquad filter similar to SoX #275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
3d7a6e3
Add basic low pass filtering
engineerchuan 5202837
Naive first implementation of biquad per SoX implementation
engineerchuan b98522e
Add highpass filtering
engineerchuan 55a904c
added cpp implementation of filtering
engineerchuan 575339b
More tests of IIR vs FIR
engineerchuan a791792
improved performance by utilizing floats
engineerchuan b8b9eff
remove extraneous Python implementations
engineerchuan 73ae473
Implement convolve function, add tests
engineerchuan ca3afeb
Move lfilter and convolve into functional, more tests
engineerchuan c19f32d
Slight reformatting of functional and test
engineerchuan 54e3834
added additional documentation for convolve and lfilter, renamed func…
engineerchuan c5534c2
Fix documentation per issue https://github.com/pytorch/audio/issues/98
engineerchuan 92e309b
Merge https://github.com/pytorch/audio
engineerchuan ae5de6a
Delete sox_convenience, delete convolve wrapper
engineerchuan ebf3d9e
Unwind other issue to separate PR, clean up
engineerchuan 361a226
Follow naming convention for sample rate in functional
engineerchuan 5522ce5
Merge https://github.com/pytorch/audio
engineerchuan 288aedc
use memset instead of implicit initialization
engineerchuan 4115a14
fix failing vctk manifest test to account for adding more test audios…
engineerchuan 5469c4d
Trying out a tensor slicing based approach
engineerchuan 59cb5ce
Adding documentation for lfilter, biquad, highpass_biquad, lowpass_bi…
engineerchuan 388a0c0
added matrix based implementation of lfilter
engineerchuan 98b9ec1
adding python lfilter implementation
engineerchuan d09f34f
Merge https://github.com/pytorch/audio
engineerchuan f4c84c8
Unwind cpp implementations for a separate PR
engineerchuan 866de47
factor out biquad, lowpass, highpass to sox compatibility
engineerchuan 3336294
Adding documentation for functional sox compatibility
engineerchuan 9608edd
removed some print statements, mainly to kick build tests again
engineerchuan 445d717
moved previous sox compatibility into functional
engineerchuan ba317d0
Merge https://github.com/pytorch/audio
engineerchuan c54991a
Fix module reference in biquad
engineerchuan a4cd66d
Small wording tweaks to test docstrings
engineerchuan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
import math | ||
import os | ||
import torch | ||
import torchaudio | ||
import torchaudio.functional as F | ||
import unittest | ||
import common_utils | ||
import time | ||
|
||
|
||
class TestFunctionalFiltering(unittest.TestCase): | ||
test_dirpath, test_dir = common_utils.create_temp_assets_dir() | ||
|
||
def test_lfilter_basic(self): | ||
""" | ||
Create a very basic signal, | ||
Then make a simple 4th order delay | ||
The output should be same as the input but shifted | ||
""" | ||
|
||
torch.random.manual_seed(42) | ||
waveform = torch.rand(2, 44100 * 10) | ||
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=torch.float32) | ||
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=torch.float32) | ||
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) | ||
|
||
assert torch.allclose( | ||
waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5 | ||
) | ||
|
||
def test_lfilter(self): | ||
""" | ||
Design an IIR lowpass filter using scipy.signal filter design | ||
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign | ||
|
||
Example | ||
>>> from scipy.signal import iirdesign | ||
>>> b, a = iirdesign(0.2, 0.3, 1, 60) | ||
""" | ||
|
||
b_coeffs = torch.tensor( | ||
[ | ||
0.00299893, | ||
-0.0051152, | ||
0.00841964, | ||
-0.00747802, | ||
0.00841964, | ||
-0.0051152, | ||
0.00299893, | ||
] | ||
) | ||
a_coeffs = torch.tensor( | ||
[ | ||
1.0, | ||
-4.8155751, | ||
10.2217618, | ||
-12.14481273, | ||
8.49018171, | ||
-3.3066882, | ||
0.56088705, | ||
] | ||
) | ||
|
||
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3") | ||
waveform, sample_rate = torchaudio.load(filepath, normalization=True) | ||
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) | ||
assert len(output_waveform.size()) == 2 | ||
assert output_waveform.size(0) == waveform.size(0) | ||
assert output_waveform.size(1) == waveform.size(1) | ||
|
||
def test_lowpass(self): | ||
|
||
""" | ||
Test biquad lowpass filter, compare to SoX implementation | ||
""" | ||
|
||
CUTOFF_FREQ = 3000 | ||
|
||
noise_filepath = os.path.join( | ||
self.test_dirpath, "assets", "whitenoise.mp3" | ||
) | ||
E = torchaudio.sox_effects.SoxEffectsChain() | ||
E.set_input_file(noise_filepath) | ||
E.append_effect_to_chain("lowpass", [CUTOFF_FREQ]) | ||
sox_output_waveform, sr = E.sox_build_flow_effects() | ||
|
||
waveform, sample_rate = torchaudio.load( | ||
noise_filepath, normalization=True | ||
) | ||
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) | ||
|
||
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) | ||
|
||
def test_highpass(self): | ||
""" | ||
Test biquad highpass filter, compare to SoX implementation | ||
""" | ||
|
||
CUTOFF_FREQ = 2000 | ||
|
||
noise_filepath = os.path.join( | ||
self.test_dirpath, "assets", "whitenoise.mp3" | ||
) | ||
E = torchaudio.sox_effects.SoxEffectsChain() | ||
E.set_input_file(noise_filepath) | ||
E.append_effect_to_chain("highpass", [CUTOFF_FREQ]) | ||
sox_output_waveform, sr = E.sox_build_flow_effects() | ||
|
||
waveform, sample_rate = torchaudio.load( | ||
noise_filepath, normalization=True | ||
) | ||
output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ) | ||
|
||
# TBD - this fails at the 1e-4 level, debug why | ||
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) | ||
|
||
def test_perf_biquad_filtering(self): | ||
|
||
fn_sine = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3") | ||
|
||
b0 = 0.4 | ||
b1 = 0.2 | ||
b2 = 0.9 | ||
a0 = 0.7 | ||
a1 = 0.2 | ||
a2 = 0.6 | ||
|
||
# SoX method | ||
E = torchaudio.sox_effects.SoxEffectsChain() | ||
E.set_input_file(fn_sine) | ||
_timing_sox = time.time() | ||
E.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2]) | ||
waveform_sox_out, sr = E.sox_build_flow_effects() | ||
_timing_sox_run_time = time.time() - _timing_sox | ||
|
||
_timing_lfilter_filtering = time.time() | ||
waveform, sample_rate = torchaudio.load(fn_sine, normalization=True) | ||
waveform_lfilter_out = F.lfilter( | ||
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]) | ||
) | ||
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering | ||
|
||
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.