@@ -33,11 +33,37 @@ def test_random_resized_crop(num_channels):
33
33
assert audio .shape [1 ] == num_samples
34
34
35
35
36
+ @pytest .mark .parametrize (
37
+ ["batch_size" , "num_channels" ],
38
+ [
39
+ (1 , 1 ),
40
+ (4 , 1 ),
41
+ (16 , 1 ),
42
+ (1 , 2 ),
43
+ (4 , 2 ),
44
+ (16 , 2 ),
45
+ ],
46
+ )
47
+ def test_random_resized_crop_batched (batch_size , num_channels ):
48
+
49
+ num_samples = 22050 * 5
50
+ audio = generate_waveform (sample_rate , num_samples , num_channels )
51
+ audio = audio .repeat (batch_size , 1 , 1 )
52
+
53
+ transform = Compose ([RandomResizedCrop (num_samples )])
54
+
55
+ audio = transform (audio )
56
+ assert audio .shape [0 ] == batch_size
57
+ assert audio .shape [1 ] == num_channels
58
+ assert audio .shape [2 ] == num_samples
59
+
60
+
36
61
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
37
62
def test_polarity (num_channels ):
38
- audio = generate_waveform (sample_rate , num_samples ,
39
- num_channels = num_channels )
40
- transform = Compose ([PolarityInversion ()],)
63
+ audio = generate_waveform (sample_rate , num_samples , num_channels = num_channels )
64
+ transform = Compose (
65
+ [PolarityInversion ()],
66
+ )
41
67
42
68
t_audio = transform (audio )
43
69
assert (t_audio == torch .neg (audio )).all ()
@@ -47,7 +73,9 @@ def test_polarity(num_channels):
47
73
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
48
74
def test_filter (num_channels ):
49
75
audio = generate_waveform (sample_rate , num_samples , num_channels )
50
- transform = Compose ([HighLowPass (sample_rate = sample_rate )],)
76
+ transform = Compose (
77
+ [HighLowPass (sample_rate = sample_rate )],
78
+ )
51
79
t_audio = transform (audio )
52
80
# torchaudio.save("tests/filter.wav", t_audio, sample_rate=sample_rate)
53
81
assert t_audio .shape == audio .shape
@@ -56,7 +84,9 @@ def test_filter(num_channels):
56
84
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
57
85
def test_delay (num_channels ):
58
86
audio = generate_waveform (sample_rate , num_samples , num_channels )
59
- transform = Compose ([Delay (sample_rate = sample_rate )],)
87
+ transform = Compose (
88
+ [Delay (sample_rate = sample_rate )],
89
+ )
60
90
61
91
t_audio = transform (audio )
62
92
# torchaudio.save("tests/delay.wav", t_audio, sample_rate=sample_rate)
@@ -66,7 +96,9 @@ def test_delay(num_channels):
66
96
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
67
97
def test_gain (num_channels ):
68
98
audio = generate_waveform (sample_rate , num_samples , num_channels )
69
- transform = Compose ([Gain ()],)
99
+ transform = Compose (
100
+ [Gain ()],
101
+ )
70
102
71
103
t_audio = transform (audio )
72
104
# torchaudio.save("tests/gain.wav", t_audio, sample_rate=sample_rate)
@@ -76,7 +108,9 @@ def test_gain(num_channels):
76
108
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
77
109
def test_noise (num_channels ):
78
110
audio = generate_waveform (sample_rate , num_samples , num_channels )
79
- transform = Compose ([Noise (min_snr = 0.5 , max_snr = 1 )],)
111
+ transform = Compose (
112
+ [Noise (min_snr = 0.5 , max_snr = 1 )],
113
+ )
80
114
81
115
t_audio = transform (audio )
82
116
# torchaudio.save("tests/noise.wav", t_audio, sample_rate=sample_rate)
@@ -87,17 +121,41 @@ def test_noise(num_channels):
87
121
def test_pitch (num_channels ):
88
122
audio = generate_waveform (sample_rate , num_samples , num_channels )
89
123
transform = Compose (
90
- [PitchShift (n_samples = num_samples , sample_rate = sample_rate )],)
124
+ [PitchShift (n_samples = num_samples , sample_rate = sample_rate )],
125
+ )
91
126
92
127
t_audio = transform (audio )
93
- # torchaudio.save("tests/pitch.wav", t_audio, sample_rate=sample_rate)
128
+ # torchaudio.save("tests/pitch.wav", audio, sample_rate=sample_rate)
129
+ # torchaudio.save("tests/t_pitch.wav", t_audio, sample_rate=sample_rate)
94
130
assert t_audio .shape == audio .shape
95
131
96
132
133
+ def test_pitch_shift_fast_ratios ():
134
+ ps = PitchShift (
135
+ n_samples = num_samples ,
136
+ sample_rate = sample_rate ,
137
+ pitch_shift_min = - 5 ,
138
+ pitch_shift_max = 5 ,
139
+ )
140
+ assert len (ps .fast_shifts ) == 20
141
+
142
+
143
+ def test_pitch_shift_no_fast_ratios ():
144
+ with pytest .raises (ValueError ):
145
+ ps = PitchShift (
146
+ n_samples = num_samples ,
147
+ sample_rate = sample_rate ,
148
+ pitch_shift_min = 4 ,
149
+ pitch_shift_max = 4 ,
150
+ )
151
+
152
+
97
153
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
98
154
def test_reverb (num_channels ):
99
155
audio = generate_waveform (sample_rate , num_samples , num_channels )
100
- transform = Compose ([Reverb (sample_rate = sample_rate )],)
156
+ transform = Compose (
157
+ [Reverb (sample_rate = sample_rate )],
158
+ )
101
159
102
160
t_audio = transform (audio )
103
161
# torchaudio.save("tests/reverb.wav", t_audio, sample_rate=sample_rate)
@@ -107,7 +165,9 @@ def test_reverb(num_channels):
107
165
@pytest .mark .parametrize ("num_channels" , [1 , 2 ])
108
166
def test_reverse (num_channels ):
109
167
stereo_audio = generate_waveform (sample_rate , num_samples , num_channels )
110
- transform = Compose ([Reverse ()],)
168
+ transform = Compose (
169
+ [Reverse ()],
170
+ )
111
171
112
172
t_audio = transform (stereo_audio )
113
173
# torchaudio.save("tests/reverse.wav", t_audio, sample_rate=sample_rate)
0 commit comments