Skip to content

Commit 3b9f6be

Browse files
committed
work in progress
1 parent 55a182b commit 3b9f6be

File tree

2 files changed

+89
-41
lines changed

2 files changed

+89
-41
lines changed

torch_wavelets/extractors.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
class TemporalFilterBankBase(metaclass=ABCMeta):
3232

33-
def __init__(self, dt=1.0, dj=0.125, wavelet=Morlet(), unbias=False, signal_length=None):
33+
def __init__(self, dt=1.0, dj=0.125, wavelet=Morlet(), unbias=False, signal_length=512):
3434

3535
self._dt = dt
3636
self._dj = dj
@@ -47,16 +47,16 @@ def compute(self, signal):
4747
raise NotImplementedError
4848

4949
def _init_filters(self):
50-
filters = []
51-
for i, scale in enumerate(self._scales):
50+
filters = [None]*len(self.scales)
51+
for scale_idx, scale in enumerate(self._scales):
5252
# number of points needed to capture wavelet
5353
M = 10 * scale / self.dt
5454
# times to use, centred at zero
5555
t = np.arange((-M + 1) / 2., (M + 1) / 2.) * dt
5656
if len(t) % 2 == 0: t = t[0:-1] # requires odd filter size
5757
# sample wavelet and normalise
5858
norm = (self.dt / scale) ** .5
59-
filters[i] = norm * self.wavelet(t, scale)
59+
filters[scale_idx] = norm * self.wavelet(t, scale)
6060
return filters
6161

6262
def compute_optimal_scales(self):
@@ -78,11 +78,11 @@ def func_to_solve(s):
7878
return self.fourier_period(s) - 2 * dt
7979
return scipy.optimize.fsolve(func_to_solve, 1)[0]
8080

81-
def power(self, signal):
81+
def power(self, x):
8282
if self.unbias:
83-
return (np.abs(self.compute(signal)).T ** 2 / self._scales).T
83+
return (np.abs(self.compute(x)).T ** 2 / self.scales).T
8484
else:
85-
return np.abs(self.compute(signal)) ** 2
85+
return np.abs(self.compute(x)) ** 2
8686

8787
@property
8888
def fourier_period(self):
@@ -136,13 +136,15 @@ def compute(self, x):
136136
num_examples = x.shape[0]
137137
output = np.zeros((num_examples, len(self.scales), x.shape[-1]), dtype=np.complex)
138138
for example_idx in range(num_examples):
139-
output[example_idx] = self.compute_single(x[example_idx])
140-
return np.squeeze(output, 0)
139+
output[example_idx] = self._compute_single(x[example_idx])
140+
if num_examples == 1:
141+
output = output.squeeze(0)
142+
return output
141143

142-
def compute_single(self, x):
144+
def _compute_single(self, x):
143145
assert x.ndim == 1, 'input signal must have single dimension.'
144146
output = np.zeros((len(self.scales), len(x)), dtype=np.complex)
145-
for scale_idx, filt in enumerate(self._filters)
147+
for scale_idx, filt in enumerate(self._filters):
146148
output[scale_idx,:] = scipy.signal.fftconvolve(x, filt, mode='same')
147149
return output
148150

@@ -199,13 +201,36 @@ def _get_padding(padding_type, kernel_size):
199201

200202
if __name__ == "__main__":
201203

202-
dt = 1.0
203-
dj = 0.125
204-
wavelet = Morlet(w0=6)
204+
import torch_wavelets.utils as utils
205+
import matplotlib.pyplot as plt
206+
207+
fps = 20
208+
dt = 1.0/fps
209+
dj = 0.125
210+
w0 = 6
205211
unbias = False
212+
wavelet = Morlet()
213+
214+
t_min = 0
215+
t_max = 10
216+
t = np.linspace(t_min, t_max, (t_max-t_min)*fps)
217+
218+
batch_size = 12
219+
220+
# Generate a batch of sine waves with random frequency
221+
random_frequencies = np.random.uniform(-0.5, 2.0, size=batch_size)
222+
batch = np.asarray([np.sin(2*np.pi*f*t) for f in random_frequencies])
223+
224+
wa = TemporalFilterBankSciPy(dt, dj, wavelet, unbias)
225+
power = wa.power(batch)
226+
227+
fig, ax = plt.subplots(3, 4, figsize=(16,8))
228+
ax = ax.flatten()
229+
for i in range(batch_size):
230+
utils.plot_scalogram(power[i], wa.scales, t, ax=ax[i])
231+
ax[i].axhline(1.0 / random_frequencies[i], lw=1, color='k')
232+
plt.show()
206233

207-
cls_scipy = TemporalFilterBankSciPy(dt, dj, wavelet, unbias)
208-
cls_torch = TemporalFilterBankTorch(dt, dj, wavelet, unbias)
209234

210235

211236

torch_wavelets/utils.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,51 @@
1717
from __future__ import print_function
1818

1919
import numpy as np
20-
from torch_wavelets.wavelets import *
21-
22-
23-
24-
def scale_distribution(wavelet, min_period, max_period, dj=0.125):
25-
"""
26-
Given a minimum and maximum period compute a distribution of scales. The choice of
27-
dj depends on the width in spectral space of the wavelet function.
28-
For the Morlet, dj=0.5 is the largest that still adequately samples scale.
29-
Smaller dj gives finer scale resolution.
30-
31-
:param wavelet: wavelet instance from (Morlet, Ricker, MexicanHat, Marr)
32-
:param min_period: float, minimum period
33-
:param max_period: float, maximum period
34-
:param dj: float, scale sample density
35-
:return: np.ndarray, containing a list of scales
36-
"""
37-
38-
assert isinstance(wavelet, (Morlet, Ricker, Marr, Mexican_hat))
39-
scale_min = wavelet.scale_from_period(min_period)
40-
scale_max = wavelet.scale_from_period(max_period)
41-
num_scales = int(np.ceil((1.0/dj) * np.log2(scale_max/scale_min)))
42-
j = np.arange(0, num_scales+1)
43-
scales = scale_min*np.power(2, j*dj)
44-
return scales
20+
import matplotlib.pyplot as plt
21+
22+
# This makes the color map of same height as the image
23+
import matplotlib.ticker as ticker
24+
from mpl_toolkits.axes_grid1 import make_axes_locatable
25+
26+
27+
def plot_scalogram(power, scales, t, normalize_columns=True, cmap=None, ax=None, scale_legend=True):
28+
29+
if not cmap:
30+
cmap = plt.get_cmap("PuBu_r")
31+
32+
if ax is None:
33+
fig, ax = plt.subplots()
34+
35+
if normalize_columns:
36+
power = power/np.max(power, axis=0)
37+
38+
T, S = np.meshgrid(t, scales)
39+
cnt = ax.contourf(T, S, power, 100, cmap=cmap)
40+
41+
# Fix for saving as PDF (aliasing)
42+
for c in cnt.collections:
43+
c.set_edgecolor("face")
44+
45+
ax.set_yscale('log')
46+
ax.set_ylabel("Scale (Log Scale)")
47+
ax.set_xlabel("Time (s)")
48+
ax.set_title("Wavelet Power Spectrum")
49+
50+
if scale_legend:
51+
52+
def format_axes_label(x, pos):
53+
return "{:.2f}".format(x)
54+
55+
divider = make_axes_locatable(ax)
56+
cax = divider.append_axes("right", size="5%", pad=0.05)
57+
plt.colorbar(
58+
cnt, cax=cax, ticks=[np.min(power), 0, np.max(power)],
59+
format=ticker.FuncFormatter(format_axes_label))
60+
61+
return ax
62+
63+
64+
65+
66+
67+

0 commit comments

Comments
 (0)