Skip to content

Commit ffd4142

Browse files
committed
added new tests for ComposeMany
1 parent 4442310 commit ffd4142

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

tests/test_compose.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
import pytest
2-
from torchaudio_augmentations import Compose, ComposeMany, RandomResizedCrop
2+
import torch
3+
from torchaudio_augmentations import (
4+
Compose,
5+
ComposeMany,
6+
RandomResizedCrop,
7+
Reverb,
8+
)
39

410
from .utils import generate_waveform
511

@@ -10,7 +16,11 @@
1016
@pytest.mark.parametrize("num_channels", [1, 2])
1117
def test_compose(num_channels):
1218
audio = generate_waveform(sample_rate, num_samples, num_channels)
13-
transform = Compose([RandomResizedCrop(num_samples),])
19+
transform = Compose(
20+
[
21+
RandomResizedCrop(num_samples),
22+
]
23+
)
1424

1525
t_audio = transform(audio)
1626
assert t_audio.shape[0] == num_channels
@@ -19,14 +29,21 @@ def test_compose(num_channels):
1929

2030
@pytest.mark.parametrize("num_channels", [1, 2])
2131
def test_compose_many(num_channels):
22-
num_augmented_samples = 10
32+
num_augmented_samples = 4
2333

2434
audio = generate_waveform(sample_rate, num_samples, num_channels)
2535
transform = ComposeMany(
26-
[RandomResizedCrop(num_samples),], num_augmented_samples=num_augmented_samples,
36+
[
37+
RandomResizedCrop(num_samples),
38+
Reverb(sample_rate),
39+
],
40+
num_augmented_samples=num_augmented_samples,
2741
)
2842

2943
t_audio = transform(audio)
3044
assert t_audio.shape[0] == num_augmented_samples
3145
assert t_audio.shape[1] == num_channels
3246
assert t_audio.shape[2] == num_samples
47+
48+
for n in range(1, num_augmented_samples):
49+
assert torch.all(t_audio[0].eq(t_audio[n])) == False

0 commit comments

Comments
 (0)