30
30
31
31
class TemporalFilterBankBase (metaclass = ABCMeta ):
32
32
33
- def __init__ (self , dt = 1.0 , dj = 0.125 , wavelet = Morlet (), unbias = False , signal_length = None ):
33
+ def __init__ (self , dt = 1.0 , dj = 0.125 , wavelet = Morlet (), unbias = False , signal_length = 512 ):
34
34
35
35
self ._dt = dt
36
36
self ._dj = dj
@@ -47,16 +47,16 @@ def compute(self, signal):
47
47
raise NotImplementedError
48
48
49
49
def _init_filters (self ):
50
- filters = []
51
- for i , scale in enumerate (self ._scales ):
50
+ filters = [None ] * len ( self . scales )
51
+ for scale_idx , scale in enumerate (self ._scales ):
52
52
# number of points needed to capture wavelet
53
53
M = 10 * scale / self .dt
54
54
# times to use, centred at zero
55
55
t = np .arange ((- M + 1 ) / 2. , (M + 1 ) / 2. ) * dt
56
56
if len (t ) % 2 == 0 : t = t [0 :- 1 ] # requires odd filter size
57
57
# sample wavelet and normalise
58
58
norm = (self .dt / scale ) ** .5
59
- filters [i ] = norm * self .wavelet (t , scale )
59
+ filters [scale_idx ] = norm * self .wavelet (t , scale )
60
60
return filters
61
61
62
62
def compute_optimal_scales (self ):
@@ -78,11 +78,11 @@ def func_to_solve(s):
78
78
return self .fourier_period (s ) - 2 * dt
79
79
return scipy .optimize .fsolve (func_to_solve , 1 )[0 ]
80
80
81
- def power (self , signal ):
81
+ def power (self , x ):
82
82
if self .unbias :
83
- return (np .abs (self .compute (signal )).T ** 2 / self ._scales ).T
83
+ return (np .abs (self .compute (x )).T ** 2 / self .scales ).T
84
84
else :
85
- return np .abs (self .compute (signal )) ** 2
85
+ return np .abs (self .compute (x )) ** 2
86
86
87
87
@property
88
88
def fourier_period (self ):
@@ -136,13 +136,15 @@ def compute(self, x):
136
136
num_examples = x .shape [0 ]
137
137
output = np .zeros ((num_examples , len (self .scales ), x .shape [- 1 ]), dtype = np .complex )
138
138
for example_idx in range (num_examples ):
139
- output [example_idx ] = self .compute_single (x [example_idx ])
140
- return np .squeeze (output , 0 )
139
+ output [example_idx ] = self ._compute_single (x [example_idx ])
140
+ if num_examples == 1 :
141
+ output = output .squeeze (0 )
142
+ return output
141
143
142
- def compute_single (self , x ):
144
+ def _compute_single (self , x ):
143
145
assert x .ndim == 1 , 'input signal must have single dimension.'
144
146
output = np .zeros ((len (self .scales ), len (x )), dtype = np .complex )
145
- for scale_idx , filt in enumerate (self ._filters )
147
+ for scale_idx , filt in enumerate (self ._filters ):
146
148
output [scale_idx ,:] = scipy .signal .fftconvolve (x , filt , mode = 'same' )
147
149
return output
148
150
@@ -199,13 +201,36 @@ def _get_padding(padding_type, kernel_size):
199
201
200
202
if __name__ == "__main__" :
201
203
202
- dt = 1.0
203
- dj = 0.125
204
- wavelet = Morlet (w0 = 6 )
204
+ import torch_wavelets .utils as utils
205
+ import matplotlib .pyplot as plt
206
+
207
+ fps = 20
208
+ dt = 1.0 / fps
209
+ dj = 0.125
210
+ w0 = 6
205
211
unbias = False
212
+ wavelet = Morlet ()
213
+
214
+ t_min = 0
215
+ t_max = 10
216
+ t = np .linspace (t_min , t_max , (t_max - t_min )* fps )
217
+
218
+ batch_size = 12
219
+
220
+ # Generate a batch of sine waves with random frequency
221
+ random_frequencies = np .random .uniform (- 0.5 , 2.0 , size = batch_size )
222
+ batch = np .asarray ([np .sin (2 * np .pi * f * t ) for f in random_frequencies ])
223
+
224
+ wa = TemporalFilterBankSciPy (dt , dj , wavelet , unbias )
225
+ power = wa .power (batch )
226
+
227
+ fig , ax = plt .subplots (3 , 4 , figsize = (16 ,8 ))
228
+ ax = ax .flatten ()
229
+ for i in range (batch_size ):
230
+ utils .plot_scalogram (power [i ], wa .scales , t , ax = ax [i ])
231
+ ax [i ].axhline (1.0 / random_frequencies [i ], lw = 1 , color = 'k' )
232
+ plt .show ()
206
233
207
- cls_scipy = TemporalFilterBankSciPy (dt , dj , wavelet , unbias )
208
- cls_torch = TemporalFilterBankTorch (dt , dj , wavelet , unbias )
209
234
210
235
211
236
0 commit comments