-
Notifications
You must be signed in to change notification settings - Fork 805
/
pe.py
155 lines (137 loc) · 6.03 KB
/
pe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import matplotlib
matplotlib.use('Agg')
import torch
import numpy as np
import os
from training.dataset.base_dataset import BaseDataset
from training.task.fs2 import FastSpeech2Task
from modules.fastspeech.pe import PitchExtractor
import utils
from utils.indexed_datasets import IndexedDataset
from utils.hparams import hparams
from utils.plot import f0_to_figure
from utils.pitch_utils import norm_interp_f0, denorm_f0
class PeDataset(BaseDataset):
    """Dataset of binarized items used to train the pitch extractor.

    Every sample exposes the item's mel-spectrogram, its ``pitch`` entry and
    the ``(f0, uv)`` pair returned by ``norm_interp_f0``. The mel and pitch
    tensors are cut to ``hparams['max_frames']`` entries; the raw f0 is cut to
    the same length before it is handed to ``norm_interp_f0``.
    """

    def __init__(self, prefix, shuffle=False):
        """Load the length table and f0 statistics for the ``prefix`` split.

        Args:
            prefix: Split name; used to build the ``{prefix}_lengths.npy``
                path and the indexed-dataset path under
                ``hparams['binary_data_dir']``.
            shuffle: Forwarded unchanged to ``BaseDataset.__init__``.

        Side effects:
            Writes ``f0_mean`` and ``f0_std`` into the module-level
            ``hparams`` (as floats when the stats file exists, else ``None``).
        """
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # Opened on first access in _get_item rather than here.
        self.indexed_ds = None

        # Pitch statistics: the instance keeps the values exactly as loaded,
        # while hparams receives plain Python floats.
        stats_path = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(stats_path):
            mean, std = np.load(stats_path)
            self.f0_mean, self.f0_std = mean, std
            hparams['f0_mean'], hparams['f0_std'] = float(mean), float(std)
        else:
            self.f0_mean = self.f0_std = None
            hparams['f0_mean'] = hparams['f0_std'] = None

        # For the test split, optionally restrict the dataset to the first
        # `num_test_samples` entries plus the explicitly listed `test_ids`.
        if prefix == 'test' and hparams['num_test_samples'] > 0:
            self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
        """Return the raw stored item for ``index``.

        When ``avail_idxs`` is set (test subset), ``index`` is first mapped
        through it. The ``IndexedDataset`` is created on the first call and
        reused afterwards.
        """
        remap = getattr(self, 'avail_idxs', None)
        if remap is not None:
            index = remap[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build one training sample (a dict) from the stored item.

        The returned ``"id"`` is the ``index`` passed by the caller, i.e. the
        position before any ``avail_idxs`` remapping done in ``_get_item``.
        """
        cfg = self.hparams
        record = self._get_item(index)
        frame_limit = cfg['max_frames']

        mel = torch.Tensor(record['mel'])[:frame_limit]
        f0, uv = norm_interp_f0(record["f0"][:frame_limit], cfg)
        pitch = torch.LongTensor(record.get("pitch"))[:frame_limit]

        return {
            "id": index,
            "item_name": record['item_name'],
            "text": record['txt'],
            "mel": mel,
            "pitch": pitch,
            "f0": f0,
            "uv": uv,
        }

    def collater(self, samples):
        """Merge a list of samples into one batch dict.

        Returns an empty dict for an empty ``samples``. Otherwise the
        per-sample ``f0``/``pitch``/``uv`` are combined with
        ``utils.collate_1d`` and the mels with ``utils.collate_2d``;
        ``mel_lengths`` records each sample's own mel length (its first
        dimension).
        """
        if len(samples) == 0:
            return {}

        sample_ids = torch.LongTensor([s['id'] for s in samples])
        names = [s['item_name'] for s in samples]
        texts = [s['text'] for s in samples]
        f0_batch = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch_batch = utils.collate_1d([s['pitch'] for s in samples])
        uv_batch = utils.collate_1d([s['uv'] for s in samples])
        mel_batch = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        return {
            'id': sample_ids,
            'item_name': names,
            'nsamples': len(samples),
            'text': texts,
            'mels': mel_batch,
            'mel_lengths': mel_lengths,
            'pitch': pitch_batch,
            'f0': f0_batch,
            'uv': uv_batch,
        }
class PitchExtractionTask(FastSpeech2Task):
    """Training task that fits a ``PitchExtractor`` on mel-spectrograms.

    Uses ``PeDataset`` for data and reuses the f0 loss provided through
    ``self.add_f0_loss`` (not defined in this class).
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = PeDataset

    def build_tts_model(self):
        """Create the model; the layer count comes from hparams."""
        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])

    def _training_step(self, sample, batch_idx, _):
        """Run one training step.

        Returns:
            ``(total_loss, loss_output)`` where ``total_loss`` sums every
            tensor in ``loss_output`` that requires grad, and ``loss_output``
            additionally carries ``'batch_size'`` (first dimension of
            ``sample['mels']``).
        """
        loss_output = self.run_model(self.model, sample)
        total_loss = sum(
            v for v in loss_output.values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        loss_output['batch_size'] = sample['mels'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and plot the first few batches.

        Returns:
            Dict with ``'losses'``, ``'total_loss'`` and ``'nsamples'``,
            passed through ``utils.tensors_to_scalars``.
        """
        losses, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)
        # Only the first `num_valid_plots` batches get an f0 figure.
        if batch_idx < hparams['num_valid_plots']:
            self.plot_pitch(batch_idx, model_out, sample)
        return outputs

    def run_model(self, model, sample, return_output=False, infer=False):
        """Feed the batch mels through ``model`` and collect the pitch loss.

        Args:
            model: Callable applied to ``sample['mels']``.
            sample: Batch dict; ``'mels'``, ``'f0'`` and ``'uv'`` are read
                (the latter two inside ``add_pitch_loss``).
            return_output: When true, also return the raw model output.
            infer: Accepted for signature compatibility; not used here.

        Returns:
            ``losses`` or ``(losses, output)`` depending on ``return_output``.
        """
        output = model(sample['mels'])
        losses = {}
        self.add_pitch_loss(output, sample, losses)
        if return_output:
            return losses, output
        return losses

    def plot_pitch(self, batch_idx, model_out, sample):
        """Log a figure of ground-truth vs. predicted f0 for the batch's first item."""
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        self.logger.experiment.add_figure(
            f'f0_{batch_idx}',
            f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
            self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add the f0 loss for ``output['pitch_pred']`` into ``losses``.

        The non-padding mask is derived from the mels themselves: a position
        counts as real when the absolute values along the last axis sum to
        more than zero.
        """
        mel = sample['mels']
        f0 = sample['f0']
        uv = sample['uv']
        nonpadding = (mel.abs().sum(-1) > 0).float()
        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)