@@ -93,73 +93,17 @@ def test_spectogram_grad_at_zero(self, power):
93
93
spec .sum ().backward ()
94
94
assert not x .grad .isnan ().sum ()
95
95
96
-
97
- class FunctionalComplex (TestBaseMixin ):
98
- complex_dtype = None
99
- real_dtype = None
100
- device = None
101
-
102
- @nested_params (
103
- [0.5 , 1.01 , 1.3 ],
104
- [True , False ],
105
- )
106
- def test_phase_vocoder_shape (self , rate , test_pseudo_complex ):
107
- """Verify the output shape of phase vocoder"""
108
- hop_length = 256
109
- num_freq = 1025
110
- num_frames = 400
111
- batch_size = 2
112
-
113
- torch .random .manual_seed (42 )
114
- spec = torch .randn (
115
- batch_size , num_freq , num_frames , dtype = self .complex_dtype , device = self .device )
116
- if test_pseudo_complex :
117
- spec = torch .view_as_real (spec )
118
-
119
- phase_advance = torch .linspace (
120
- 0 ,
121
- np .pi * hop_length ,
122
- num_freq ,
123
- dtype = self .real_dtype , device = self .device )[..., None ]
124
-
125
- spec_stretch = F .phase_vocoder (spec , rate = rate , phase_advance = phase_advance )
126
-
127
- assert spec .dim () == spec_stretch .dim ()
128
- expected_shape = torch .Size ([batch_size , num_freq , int (np .ceil (num_frames / rate ))])
129
- output_shape = (torch .view_as_complex (spec_stretch ) if test_pseudo_complex else spec_stretch ).shape
130
- assert output_shape == expected_shape
131
-
132
-
133
- class FunctionalCPUOnly (TestBaseMixin ):
134
- def test_create_fb_matrix_no_warning_high_n_freq (self ):
135
- with warnings .catch_warnings (record = True ) as w :
136
- warnings .simplefilter ("always" )
137
- F .create_fb_matrix (288 , 0 , 8000 , 128 , 16000 )
138
- assert len (w ) == 0
139
-
140
- def test_create_fb_matrix_no_warning_low_n_mels (self ):
141
- with warnings .catch_warnings (record = True ) as w :
142
- warnings .simplefilter ("always" )
143
- F .create_fb_matrix (201 , 0 , 8000 , 89 , 16000 )
144
- assert len (w ) == 0
145
-
146
- def test_create_fb_matrix_warning (self ):
147
- with warnings .catch_warnings (record = True ) as w :
148
- warnings .simplefilter ("always" )
149
- F .create_fb_matrix (201 , 0 , 8000 , 128 , 16000 )
150
- assert len (w ) == 1
151
-
152
96
def test_compute_deltas_one_channel (self ):
153
- specgram = torch .tensor ([[[1.0 , 2.0 , 3.0 , 4.0 ]]])
154
- expected = torch .tensor ([[[0.5 , 1.0 , 1.0 , 0.5 ]]])
97
+ specgram = torch .tensor ([[[1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = self . dtype , device = self . device )
98
+ expected = torch .tensor ([[[0.5 , 1.0 , 1.0 , 0.5 ]]], dtype = self . dtype , device = self . device )
155
99
computed = F .compute_deltas (specgram , win_length = 3 )
156
100
self .assertEqual (computed , expected )
157
101
158
102
def test_compute_deltas_two_channels (self ):
159
103
specgram = torch .tensor ([[[1.0 , 2.0 , 3.0 , 4.0 ],
160
- [1.0 , 2.0 , 3.0 , 4.0 ]]])
104
+ [1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = self . dtype , device = self . device )
161
105
expected = torch .tensor ([[[0.5 , 1.0 , 1.0 , 0.5 ],
162
- [0.5 , 1.0 , 1.0 , 0.5 ]]])
106
+ [0.5 , 1.0 , 1.0 , 0.5 ]]], dtype = self . dtype , device = self . device )
163
107
computed = F .compute_deltas (specgram , win_length = 3 )
164
108
self .assertEqual (computed , expected )
165
109
@@ -190,7 +134,7 @@ def test_amplitude_to_DB_reversible(self, shape):
190
134
db_mult = math .log10 (max (amin , ref ))
191
135
192
136
torch .manual_seed (0 )
193
- spec = torch .rand (* shape ) * 200
137
+ spec = torch .rand (* shape , dtype = self . dtype , device = self . device ) * 200
194
138
195
139
# Spectrogram amplitude -> DB -> amplitude
196
140
db = F .amplitude_to_DB (spec , amplitude_mult , amin , db_mult , top_db = None )
@@ -218,7 +162,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
218
162
# each spectrogram still need to be predictable. The max determines the
219
163
# decibel cutoff, and the distance from the min must be large enough
220
164
# that it triggers a clamp.
221
- spec = torch .rand (* shape )
165
+ spec = torch .rand (* shape , dtype = self . dtype , device = self . device )
222
166
# Ensure each spectrogram has a min of 0 and a max of 1.
223
167
spec -= spec .amin ([- 2 , - 1 ])[..., None , None ]
224
168
spec /= spec .amax ([- 2 , - 1 ])[..., None , None ]
@@ -245,7 +189,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
245
189
)
246
190
def test_complex_norm (self , shape , power ):
247
191
torch .random .manual_seed (42 )
248
- complex_tensor = torch .randn (* shape )
192
+ complex_tensor = torch .randn (* shape , dtype = self . dtype , device = self . device )
249
193
expected_norm_tensor = complex_tensor .pow (2 ).sum (- 1 ).pow (power / 2 )
250
194
norm_tensor = F .complex_norm (complex_tensor , power )
251
195
self .assertEqual (norm_tensor , expected_norm_tensor , atol = 1e-5 , rtol = 1e-5 )
@@ -255,7 +199,7 @@ def test_complex_norm(self, shape, power):
255
199
)
256
200
def test_mask_along_axis (self , shape , mask_param , mask_value , axis ):
257
201
torch .random .manual_seed (42 )
258
- specgram = torch .randn (* shape )
202
+ specgram = torch .randn (* shape , dtype = self . dtype , device = self . device )
259
203
mask_specgram = F .mask_along_axis (specgram , mask_param , mask_value , axis )
260
204
261
205
other_axis = 1 if axis == 2 else 2
@@ -271,7 +215,7 @@ def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
271
215
@parameterized .expand (list (itertools .product ([100 ], [0. , 30. ], [2 , 3 ])))
272
216
def test_mask_along_axis_iid (self , mask_param , mask_value , axis ):
273
217
torch .random .manual_seed (42 )
274
- specgrams = torch .randn (4 , 2 , 1025 , 400 )
218
+ specgrams = torch .randn (4 , 2 , 1025 , 400 , dtype = self . dtype , device = self . device )
275
219
276
220
mask_specgrams = F .mask_along_axis_iid (specgrams , mask_param , mask_value , axis )
277
221
@@ -282,3 +226,59 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
282
226
283
227
assert mask_specgrams .size () == specgrams .size ()
284
228
assert (num_masked_columns < mask_param ).sum () == num_masked_columns .numel ()
229
+
230
+
231
+ class FunctionalComplex (TestBaseMixin ):
232
+ complex_dtype = None
233
+ real_dtype = None
234
+ device = None
235
+
236
+ @nested_params (
237
+ [0.5 , 1.01 , 1.3 ],
238
+ [True , False ],
239
+ )
240
+ def test_phase_vocoder_shape (self , rate , test_pseudo_complex ):
241
+ """Verify the output shape of phase vocoder"""
242
+ hop_length = 256
243
+ num_freq = 1025
244
+ num_frames = 400
245
+ batch_size = 2
246
+
247
+ torch .random .manual_seed (42 )
248
+ spec = torch .randn (
249
+ batch_size , num_freq , num_frames , dtype = self .complex_dtype , device = self .device )
250
+ if test_pseudo_complex :
251
+ spec = torch .view_as_real (spec )
252
+
253
+ phase_advance = torch .linspace (
254
+ 0 ,
255
+ np .pi * hop_length ,
256
+ num_freq ,
257
+ dtype = self .real_dtype , device = self .device )[..., None ]
258
+
259
+ spec_stretch = F .phase_vocoder (spec , rate = rate , phase_advance = phase_advance )
260
+
261
+ assert spec .dim () == spec_stretch .dim ()
262
+ expected_shape = torch .Size ([batch_size , num_freq , int (np .ceil (num_frames / rate ))])
263
+ output_shape = (torch .view_as_complex (spec_stretch ) if test_pseudo_complex else spec_stretch ).shape
264
+ assert output_shape == expected_shape
265
+
266
+
267
+ class FunctionalCPUOnly (TestBaseMixin ):
268
+ def test_create_fb_matrix_no_warning_high_n_freq (self ):
269
+ with warnings .catch_warnings (record = True ) as w :
270
+ warnings .simplefilter ("always" )
271
+ F .create_fb_matrix (288 , 0 , 8000 , 128 , 16000 )
272
+ assert len (w ) == 0
273
+
274
+ def test_create_fb_matrix_no_warning_low_n_mels (self ):
275
+ with warnings .catch_warnings (record = True ) as w :
276
+ warnings .simplefilter ("always" )
277
+ F .create_fb_matrix (201 , 0 , 8000 , 89 , 16000 )
278
+ assert len (w ) == 0
279
+
280
+ def test_create_fb_matrix_warning (self ):
281
+ with warnings .catch_warnings (record = True ) as w :
282
+ warnings .simplefilter ("always" )
283
+ F .create_fb_matrix (201 , 0 , 8000 , 128 , 16000 )
284
+ assert len (w ) == 1
0 commit comments