Skip to content

Commit 7ef8e9a

Browse files
authored
Fix tests (#168)
1 parent 8d44c8e commit 7ef8e9a

File tree

7 files changed

+27
-407
lines changed

7 files changed

+27
-407
lines changed

analytics/tests/tests/models/test_tacotron2.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@ def test_tacotron2_model(self):
2020
with open("analytics/tests/fixtures/ljtest/taco2_lj2lj.json") as f:
2121
config.update(json.load(f))
2222
hparams = HParams(**config)
23-
if torch.cuda.is_available():
24-
hparams.cudnn_enabled = True
23+
24+
hparams.cudnn_enabled = False
2525
model = Tacotron2(hparams)
26-
if torch.cuda.is_available():
27-
model.cuda()
2826

2927
trainer = Tacotron2Trainer(hparams, rank=0, world_size=0)
3028
(
@@ -43,16 +41,16 @@ def test_tacotron2_model(self):
4341
"speaker_ids",
4442
"gst",
4543
"mel_padded",
46-
"output_lengths",
44+
"mel_lengths",
4745
]
4846
)
4947
model_output = model(
5048
input_text=model_input["text_int_padded"],
5149
input_lengths=model_input["input_lengths"],
5250
speaker_ids=model_input["speaker_ids"],
53-
embedded_gst=model_input["gst"],
51+
embedded_gst=model_input.get("gst", None),
5452
targets=model_input["mel_padded"],
55-
output_lengths=model_input["output_lengths"],
53+
output_lengths=model_input["mel_lengths"],
5654
)
5755

5856
# 'mel_outputs', 'mel_outputs_postnet', 'gate_predicted', 'output_lengths', 'alignments'

analytics/tests/tests/test_data_loader.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_batch_structure(self):
3131
collate_fn = Collate()
3232
dl = DataLoader(ds, 12, collate_fn=collate_fn)
3333
for i, batch in enumerate(dl):
34-
assert len(batch) == 9
34+
assert len(batch) == 6
3535

3636
def test_batch_dimensions(self):
3737
ds = Data(
@@ -44,13 +44,13 @@ def test_batch_dimensions(self):
4444
collate_fn = Collate()
4545
dl = DataLoader(ds, 12, collate_fn=collate_fn)
4646
for i, batch in enumerate(dl):
47-
output_lengths = batch["output_lengths"]
48-
gate_target = batch["gate_target"]
47+
output_lengths = batch["mel_lengths"]
48+
gate_target = batch["gate_padded"]
4949
mel_padded = batch["mel_padded"]
5050
assert output_lengths.item() == 566
5151
assert gate_target.size(1) == 566
5252
assert mel_padded.size(2) == 566
53-
assert len(batch) == 9
53+
assert len(batch) == 6
5454

5555
def test_batch_dimensions_partial(self):
5656
ds = Data(
@@ -63,9 +63,9 @@ def test_batch_dimensions_partial(self):
6363
collate_fn = Collate(n_frames_per_step=5)
6464
dl = DataLoader(ds, 12, collate_fn=collate_fn)
6565
for i, batch in enumerate(dl):
66-
assert batch["output_lengths"].item() == 566
66+
assert batch["mel_lengths"].item() == 566
6767
assert (
6868
batch["mel_padded"].size(2) == 566
6969
) # I'm not sure why this was 570 - maybe 566 + 5 (i.e. the n_frames_per_step)
70-
assert batch["gate_target"].size(1) == 566
71-
assert len(batch) == 9
70+
assert batch["gate_padded"].size(1) == 566
71+
assert len(batch) == 6

analytics/tests/tests/trainer/test_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from uberduck_ml_dev.trainer.base import TTSTrainer
21
import torch
32
import math
3+
44
from uberduck_ml_dev.vendor.tfcompat.hparam import HParams
55
from uberduck_ml_dev.trainer.base import DEFAULTS as TRAINER_DEFAULTS
6+
from uberduck_ml_dev.trainer.base import TTSTrainer
7+
from uberduck_ml_dev.models.common import MelSTFT
68

79

810
class TestTrainer:
@@ -22,7 +24,8 @@ def test_trainer_base(self):
2224

2325
assert trainer.cudnn_enabled == True
2426
mel = torch.load("analytics/tests/fixtures/stevejobs-1.pt")
25-
audio = trainer.sample(mel)
27+
mel_stft = MelSTFT()
28+
audio = mel_stft.griffin_lim(mel)
2629
assert audio.size(0) == 1
2730

2831

uberduck_ml_dev/data/data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@
5151
F0_MAX = 640
5252

5353

54-
# NOTE (Sam): generic dataset class for all purposes avoids writing redundant methods (e.g. get pitch when text isn't available).
55-
# However, functional factorization of this dataloader (e.g. get_mels) and merging classes as needed would be preferable.
56-
# NOTE (Sam): "load" means load from file. "return" means return to collate. "get" is a functional element. "has" and "with" are tbd equivalent in trainer/model.
54+
# TODO (Sam): replace with Dataset.
5755
class Data(Dataset):
5856
def __init__(
5957
self,

uberduck_ml_dev/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, pos_weight):
1919

2020
# NOTE (Sam): making function inputs explicit makes less sense in situations like this with obvious subcategories.
2121
def forward(self, model_output: Batch, target: Batch):
22-
mel_target, gate_target = target["mel_padded"], target["gate_target"]
22+
mel_target, gate_target = target["mel_padded"], target["gate_padded"]
2323
mel_target.requires_grad = False
2424
gate_target.requires_grad = False
2525
mel_out, mel_out_postnet, gate_out = (

0 commit comments

Comments
 (0)