1
1
import pytest
2
- from torchaudio_augmentations import Compose , ComposeMany , RandomResizedCrop
2
+ import torch
3
+ from torchaudio_augmentations import (
4
+ Compose ,
5
+ ComposeMany ,
6
+ RandomResizedCrop ,
7
+ Reverb ,
8
+ )
3
9
4
10
from .utils import generate_waveform
5
11
10
16
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
11
17
def test_compose (num_channels ):
12
18
audio = generate_waveform (sample_rate , num_samples , num_channels )
13
- transform = Compose ([RandomResizedCrop (num_samples ),])
19
+ transform = Compose (
20
+ [
21
+ RandomResizedCrop (num_samples ),
22
+ ]
23
+ )
14
24
15
25
t_audio = transform (audio )
16
26
assert t_audio .shape [0 ] == num_channels
@@ -19,14 +29,21 @@ def test_compose(num_channels):
19
29
20
30
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
21
31
def test_compose_many (num_channels ):
22
- num_augmented_samples = 10
32
+ num_augmented_samples = 4
23
33
24
34
audio = generate_waveform (sample_rate , num_samples , num_channels )
25
35
transform = ComposeMany (
26
- [RandomResizedCrop (num_samples ),], num_augmented_samples = num_augmented_samples ,
36
+ [
37
+ RandomResizedCrop (num_samples ),
38
+ Reverb (sample_rate ),
39
+ ],
40
+ num_augmented_samples = num_augmented_samples ,
27
41
)
28
42
29
43
t_audio = transform (audio )
30
44
assert t_audio .shape [0 ] == num_augmented_samples
31
45
assert t_audio .shape [1 ] == num_channels
32
46
assert t_audio .shape [2 ] == num_samples
47
+
48
+ for n in range (1 , num_augmented_samples ):
49
+ assert torch .all (t_audio [0 ].eq (t_audio [n ])) == False
0 commit comments