Skip to content

Commit a418335

Browse files
committed
Added filter tests
1 parent 5932478 commit a418335

File tree

5 files changed

+228
-9
lines changed

5 files changed

+228
-9
lines changed

adc_eval/eval/spectrum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def plot_spectrum(
132132
fig, ax = plt.subplots(figsize=(15, 8))
133133
ax.plot(freq / xscale, psd_out)
134134
ax.set_ylabel(f"Power Spectrum ({yunits})", fontsize=18)
135-
ax.set_xlabel(f"Frequency ({fscale})", fontsize=16)
135+
ax.set_xlabel(f"Frequency ({fscale[0]})", fontsize=16)
136136
ax.set_title("Output Power Spectrum", fontsize=16)
137137
ax.set_xlim([xmin, fs / (2 * xscale)])
138138
ax.set_ylim([1.1 * min(psd_out), 1])

adc_eval/filter.py renamed to adc_eval/filt.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from scipy.signal import remez, freqz
55
import matplotlib.pyplot as plt
6+
from adc_eval import signals
67
from adc_eval.eval import spectrum
78
from adc_eval.eval import calc
89

@@ -38,13 +39,35 @@ class CICDecimate:
3839

3940
def __init__(self, dec=2, order=1, fs=1):
4041
"""Initialize the CIC filter."""
41-
self.dec = dec
42-
self.order = order
42+
self._dec = dec
43+
self._order = order
4344
self.fs = fs
44-
self.gain = self.dec**self.order
45+
self.gain = dec**order
4546
self._xout = None
4647
self._xfilt = None
4748

49+
@property
50+
def dec(self):
51+
"""Returns the decimation factor."""
52+
return self._dec
53+
54+
@dec.setter
55+
def dec(self, value):
56+
"""Sets the decimation factor."""
57+
self._dec = value
58+
self.gain = value**self._order
59+
60+
@property
61+
def order(self):
62+
"""Returns the order of the filter."""
63+
return self._order
64+
65+
@order.setter
66+
def order(self, value):
67+
"""Sets the filter order."""
68+
self._order = value
69+
self.gain = self.dec**value
70+
4871
@property
4972
def out(self):
5073
"""Filtered and decimated output data."""
@@ -70,9 +93,11 @@ def filt(self, xarray):
7093

7194
self._xfilt = ycomb / self.gain
7295

73-
def decimate(self):
96+
def decimate(self, xarray=None):
7497
"""decimation routine."""
75-
self._xout = self._xfilt[:: self.dec]
98+
if xarray is None:
99+
xarray = self._xfilt
100+
self._xout = xarray[:: self.dec]
76101

77102
def run(self, xarray):
78103
"""Runs filtering and decimation on input list."""
@@ -81,8 +106,7 @@ def run(self, xarray):
81106

82107
def response(self, fft, no_plot=False):
83108
"""Plots the frequency response of the pre-decimated filter."""
84-
xin = np.zeros(fft)
85-
xin[0] = 1
109+
xin = signals.impulse(fft)
86110
self.filt(xin)
87111
(freq, psd, stats) = spectrum.analyze(
88112
self._xfilt * fft / np.sqrt(2),
@@ -147,7 +171,7 @@ def __init__(self, dec=1, fs=1, bit_depth=16, coeffs=None):
147171
self.dec = dec
148172
self.fs = fs
149173
self.bit_depth = bit_depth
150-
self.ntaps = np.size(coeffs) if coeffs is None else 0
174+
self.ntaps = np.size(coeffs) if coeffs is not None else 0
151175
self.yfilt = None
152176
self._out = None
153177

adc_eval/signals.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,10 @@ def sin(t, amp=0.5, offset=0.5, freq=1e3, ph0=0):
1616
def noise(t, mean=0, std=0.1):
1717
"""Generate random noise."""
1818
return np.random.normal(mean, std, size=len(t))
19+
20+
21+
def impulse(nsamp, mag=1):
22+
"""Generate an impulse input."""
23+
data = np.zeros(nsamp)
24+
data[0] = mag
25+
return data

tests/test_filt.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""Test the filter module."""
2+
3+
import pytest
4+
import numpy as np
5+
from unittest import mock
6+
from adc_eval import filt
7+
from adc_eval import signals
8+
9+
10+
@pytest.mark.parametrize("dec", np.random.randint(1, 20, 4))
11+
def test_cic_decimate_set_dec_updates_gain(dec):
12+
"""Tests that changing decimation factor updates gain."""
13+
cicfilt = filt.CICDecimate(dec=1, order=2)
14+
assert cicfilt.gain == 1
15+
16+
cicfilt.dec = dec
17+
assert cicfilt.gain == (dec**2)
18+
19+
20+
@pytest.mark.parametrize("order", np.random.randint(1, 20, 4))
21+
def test_cic_decimate_set_order_updates_gain(order):
22+
"""Tests that changing filter order updates gain."""
23+
cicfilt = filt.CICDecimate(dec=2, order=1)
24+
assert cicfilt.gain == 2
25+
26+
cicfilt.order = order
27+
assert cicfilt.gain == (2**order)
28+
29+
30+
def test_cic_decimate_returns_ndarray():
31+
"""Tests the CICDecimate output data conversion."""
32+
cicfilt = filt.CICDecimate()
33+
data = np.random.randn(100).tolist()
34+
cicfilt._xout = data
35+
36+
assert type(cicfilt.out) == type(np.array(list()))
37+
assert cicfilt.out.all() == np.array(data).all()
38+
39+
40+
@pytest.mark.parametrize("dec", np.random.randint(1, 20, 4))
41+
def test_cic_decimate_function(dec):
42+
"""Tests the CICDecimate decimate function."""
43+
cicfilt = filt.CICDecimate(dec=dec)
44+
data = np.random.randn(100)
45+
cicfilt.decimate(data)
46+
47+
exp_result = data[::dec]
48+
49+
assert cicfilt.out.size == exp_result.size
50+
assert cicfilt.out.all() == exp_result.all()
51+
52+
53+
def test_cic_decimate_function_none_input():
54+
"""Tests the CICDecimate decimate function with no input arg."""
55+
cicfilt = filt.CICDecimate(dec=1)
56+
data = np.random.randn(100)
57+
cicfilt._xfilt = data
58+
cicfilt.decimate()
59+
60+
exp_result = data
61+
62+
assert cicfilt.out.size == exp_result.size
63+
assert cicfilt.out.all() == exp_result.all()
64+
65+
66+
@pytest.mark.parametrize("nlen", np.random.randint(8, 2**10, 4))
67+
def test_cic_decimate_all_ones(nlen):
68+
"""Test the CICDecimate filtering with all ones."""
69+
cicfilt = filt.CICDecimate(dec=1, order=1)
70+
data = np.ones(nlen)
71+
cicfilt.run(data)
72+
73+
exp_data = data.copy()
74+
exp_data[0] = 0
75+
76+
assert cicfilt.out.all() == exp_data.all()
77+
78+
79+
@pytest.mark.parametrize("nlen", np.random.randint(8, 2**10, 4))
80+
def test_cic_decimate_all_zeros(nlen):
81+
"""Test the CICDecimate filtering with all zeros."""
82+
cicfilt = filt.CICDecimate(dec=1, order=1)
83+
data = np.zeros(nlen)
84+
cicfilt.run(data)
85+
86+
exp_data = data.copy()
87+
88+
assert cicfilt.out.all() == exp_data.all()
89+
90+
91+
@pytest.mark.parametrize("nlen", np.random.randint(8, 2**10, 4))
92+
def test_cic_decimate_impulse(nlen):
93+
"""Test the CICDecimate filtering with impulse."""
94+
cicfilt = filt.CICDecimate(dec=1, order=1)
95+
data = signals.impulse(nlen)
96+
cicfilt.run(data)
97+
98+
exp_data = np.concatenate([[0], data[0:-1]])
99+
100+
assert cicfilt.out.all() == exp_data.all()
101+
102+
103+
def test_fir_lowpass_returns_ndarray():
104+
"""Tests the FIRLowPass output data conversion."""
105+
fir = filt.FIRLowPass()
106+
data = np.random.randn(100).tolist()
107+
fir._out = data
108+
109+
assert type(fir.out) == type(np.array(list()))
110+
assert fir.out.all() == np.array(data).all()
111+
112+
113+
@pytest.mark.parametrize("dec", np.random.randint(1, 20, 4))
114+
def test_fir_decimate_function(dec):
115+
"""Tests the FIRLowPass decimate function."""
116+
fir = filt.FIRLowPass(dec=dec)
117+
data = np.random.randn(100)
118+
fir.decimate(data)
119+
120+
exp_result = data[::dec]
121+
122+
assert fir.out.size == exp_result.size
123+
assert fir.out.all() == exp_result.all()
124+
125+
126+
def test_fir_decimate_function_none_input():
127+
"""Tests the FIRLowPass decimate function with no input arg."""
128+
fir = filt.FIRLowPass(dec=1)
129+
data = np.random.randn(100)
130+
fir.yfilt = data
131+
fir.decimate()
132+
133+
exp_result = data
134+
135+
assert fir.out.size == exp_result.size
136+
assert fir.out.all() == exp_result.all()
137+
138+
139+
@mock.patch("adc_eval.filt.remez")
140+
def test_fir_lowpass_tap_generation(mock_remez, capfd):
141+
"""Tests the FIRLowPass decimate function."""
142+
fir = filt.FIRLowPass()
143+
fir.ntaps = 3
144+
fir.bit_depth = 12
145+
mock_remez.return_value = np.ones(3)
146+
147+
(taps, coeffs) = fir.generate_taps(0.1)
148+
149+
captured = capfd.readouterr()
150+
exp_coeffs = [2**12, 2**12, 2**12]
151+
152+
assert "WARNING" in captured.out
153+
assert taps == 3
154+
assert coeffs == exp_coeffs
155+
156+
157+
@pytest.mark.parametrize("ntaps", np.random.randint(3, 511, 5))
158+
def test_fir_lowpass_run(ntaps):
159+
"""Tests the FIRLowPass run function."""
160+
fir = filt.FIRLowPass()
161+
fir.ntaps = ntaps
162+
fir.bit_depth = 10
163+
fir.coeffs = 2**10 * np.ones(ntaps)
164+
data = signals.impulse(2**12)
165+
fir.run(data)
166+
167+
exp_sum = np.ceil((ntaps + 1) / 2)
168+
out_sum = sum(fir.out)
169+
170+
tap_val = int(exp_sum)
171+
172+
assert fir.out.size == data.size
173+
assert max(fir.out) == 1
174+
assert min(fir.out) == 0
175+
assert fir.out[0:tap_val].all() == 1
176+
assert fir.out[tap_val + 1 :].all() == 0
177+
assert out_sum == exp_sum

tests/test_signals.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,14 @@ def test_noise_length(std):
6868

6969
# Now check that noise is gaussian
7070
assert shapiro.pvalue < 0.01
71+
72+
73+
@pytest.mark.parametrize("nlen", np.random.randint(2, 2**12, 3))
74+
@pytest.mark.parametrize("mag", np.random.uniform(0.1, 100, 3))
75+
def test_impulse(nlen, mag):
76+
"""Test impulse generation with random length and amplitude."""
77+
data = signals.impulse(nlen, mag)
78+
79+
assert data.size == nlen
80+
assert data[0] == mag
81+
assert data[1:].all() == 0

0 commit comments

Comments
 (0)