Skip to content

Commit b5d8027

Browse files
authored
Run functional tests on GPU as well as CPU (#1475)
1 parent bdd7b33 commit b5d8027

File tree

2 files changed

+68
-68
lines changed

2 files changed

+68
-68
lines changed

test/torchaudio_unittest/functional/functional_cpu_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ def test_lfilter_9th_order_filter_stability(self):
1616
super().test_lfilter_9th_order_filter_stability()
1717

1818

19-
class TestFunctionalFloat64(Functional, FunctionalCPUOnly, PytorchTestCase):
19+
class TestFunctionalFloat64(Functional, PytorchTestCase):
2020
dtype = torch.float64
2121
device = torch.device('cpu')
2222

2323

24-
class TestFunctionalComplex64(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
24+
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
2525
complex_dtype = torch.complex64
2626
real_dtype = torch.float32
2727
device = torch.device('cpu')
2828

2929

30-
class TestFunctionalComplex128(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
30+
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
3131
complex_dtype = torch.complex128
3232
real_dtype = torch.float64
3333
device = torch.device('cpu')

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -93,73 +93,17 @@ def test_spectogram_grad_at_zero(self, power):
9393
spec.sum().backward()
9494
assert not x.grad.isnan().sum()
9595

96-
97-
class FunctionalComplex(TestBaseMixin):
98-
complex_dtype = None
99-
real_dtype = None
100-
device = None
101-
102-
@nested_params(
103-
[0.5, 1.01, 1.3],
104-
[True, False],
105-
)
106-
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
107-
"""Verify the output shape of phase vocoder"""
108-
hop_length = 256
109-
num_freq = 1025
110-
num_frames = 400
111-
batch_size = 2
112-
113-
torch.random.manual_seed(42)
114-
spec = torch.randn(
115-
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
116-
if test_pseudo_complex:
117-
spec = torch.view_as_real(spec)
118-
119-
phase_advance = torch.linspace(
120-
0,
121-
np.pi * hop_length,
122-
num_freq,
123-
dtype=self.real_dtype, device=self.device)[..., None]
124-
125-
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
126-
127-
assert spec.dim() == spec_stretch.dim()
128-
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
129-
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
130-
assert output_shape == expected_shape
131-
132-
133-
class FunctionalCPUOnly(TestBaseMixin):
134-
def test_create_fb_matrix_no_warning_high_n_freq(self):
135-
with warnings.catch_warnings(record=True) as w:
136-
warnings.simplefilter("always")
137-
F.create_fb_matrix(288, 0, 8000, 128, 16000)
138-
assert len(w) == 0
139-
140-
def test_create_fb_matrix_no_warning_low_n_mels(self):
141-
with warnings.catch_warnings(record=True) as w:
142-
warnings.simplefilter("always")
143-
F.create_fb_matrix(201, 0, 8000, 89, 16000)
144-
assert len(w) == 0
145-
146-
def test_create_fb_matrix_warning(self):
147-
with warnings.catch_warnings(record=True) as w:
148-
warnings.simplefilter("always")
149-
F.create_fb_matrix(201, 0, 8000, 128, 16000)
150-
assert len(w) == 1
151-
15296
def test_compute_deltas_one_channel(self):
153-
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
154-
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
97+
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
98+
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
15599
computed = F.compute_deltas(specgram, win_length=3)
156100
self.assertEqual(computed, expected)
157101

158102
def test_compute_deltas_two_channels(self):
159103
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
160-
[1.0, 2.0, 3.0, 4.0]]])
104+
[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
161105
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
162-
[0.5, 1.0, 1.0, 0.5]]])
106+
[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
163107
computed = F.compute_deltas(specgram, win_length=3)
164108
self.assertEqual(computed, expected)
165109

@@ -190,7 +134,7 @@ def test_amplitude_to_DB_reversible(self, shape):
190134
db_mult = math.log10(max(amin, ref))
191135

192136
torch.manual_seed(0)
193-
spec = torch.rand(*shape) * 200
137+
spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200
194138

195139
# Spectrogram amplitude -> DB -> amplitude
196140
db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
@@ -218,7 +162,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
218162
# each spectrogram still need to be predictable. The max determines the
219163
# decibel cutoff, and the distance from the min must be large enough
220164
# that it triggers a clamp.
221-
spec = torch.rand(*shape)
165+
spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
222166
# Ensure each spectrogram has a min of 0 and a max of 1.
223167
spec -= spec.amin([-2, -1])[..., None, None]
224168
spec /= spec.amax([-2, -1])[..., None, None]
@@ -245,7 +189,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
245189
)
246190
def test_complex_norm(self, shape, power):
247191
torch.random.manual_seed(42)
248-
complex_tensor = torch.randn(*shape)
192+
complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
249193
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
250194
norm_tensor = F.complex_norm(complex_tensor, power)
251195
self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
@@ -255,7 +199,7 @@ def test_complex_norm(self, shape, power):
255199
)
256200
def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
257201
torch.random.manual_seed(42)
258-
specgram = torch.randn(*shape)
202+
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
259203
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
260204

261205
other_axis = 1 if axis == 2 else 2
@@ -271,7 +215,7 @@ def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
271215
@parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
272216
def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
273217
torch.random.manual_seed(42)
274-
specgrams = torch.randn(4, 2, 1025, 400)
218+
specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
275219

276220
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
277221

@@ -282,3 +226,59 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
282226

283227
assert mask_specgrams.size() == specgrams.size()
284228
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
229+
230+
231+
class FunctionalComplex(TestBaseMixin):
232+
complex_dtype = None
233+
real_dtype = None
234+
device = None
235+
236+
@nested_params(
237+
[0.5, 1.01, 1.3],
238+
[True, False],
239+
)
240+
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
241+
"""Verify the output shape of phase vocoder"""
242+
hop_length = 256
243+
num_freq = 1025
244+
num_frames = 400
245+
batch_size = 2
246+
247+
torch.random.manual_seed(42)
248+
spec = torch.randn(
249+
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
250+
if test_pseudo_complex:
251+
spec = torch.view_as_real(spec)
252+
253+
phase_advance = torch.linspace(
254+
0,
255+
np.pi * hop_length,
256+
num_freq,
257+
dtype=self.real_dtype, device=self.device)[..., None]
258+
259+
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
260+
261+
assert spec.dim() == spec_stretch.dim()
262+
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
263+
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
264+
assert output_shape == expected_shape
265+
266+
267+
class FunctionalCPUOnly(TestBaseMixin):
268+
def test_create_fb_matrix_no_warning_high_n_freq(self):
269+
with warnings.catch_warnings(record=True) as w:
270+
warnings.simplefilter("always")
271+
F.create_fb_matrix(288, 0, 8000, 128, 16000)
272+
assert len(w) == 0
273+
274+
def test_create_fb_matrix_no_warning_low_n_mels(self):
275+
with warnings.catch_warnings(record=True) as w:
276+
warnings.simplefilter("always")
277+
F.create_fb_matrix(201, 0, 8000, 89, 16000)
278+
assert len(w) == 0
279+
280+
def test_create_fb_matrix_warning(self):
281+
with warnings.catch_warnings(record=True) as w:
282+
warnings.simplefilter("always")
283+
F.create_fb_matrix(201, 0, 8000, 128, 16000)
284+
assert len(w) == 1

0 commit comments

Comments
 (0)