-
Notifications
You must be signed in to change notification settings - Fork 89
/
inference.py
109 lines (96 loc) · 3.59 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import numpy as np
import sys
import os
import torch.nn as nn
import torch.nn.functional as F
import yaml
import pickle
from model import AE
from utils import *
from functools import reduce
import json
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from argparse import ArgumentParser, Namespace
from scipy.io.wavfile import write
import random
from preprocess.tacotron.utils import melspectrogram2wav
from preprocess.tacotron.utils import get_spectrograms
import librosa
class Inferencer(object):
    """Convert a source utterance conditioned on a target utterance.

    The object builds an ``AE`` model from ``config``, loads its weights
    from ``args.model`` and reads normalization statistics (``'mean'`` and
    ``'std'``) from the pickle file at ``args.attr``.

    Attributes read from ``args``: ``attr``, ``model``, ``source``,
    ``target``, ``output`` and ``sample_rate``.
    """

    def __init__(self, config, args):
        """Build the model, load its weights and the normalization stats.

        Args:
            config: hyperparameter mapping; ``config['data_loader']
                ['frame_size']`` is read here, the rest is passed to ``AE``.
            args: parsed command-line namespace (see class docstring).
        """
        # config holds the hyperparameter values
        self.config = config
        print(config)
        # args store other information
        self.args = args
        print(self.args)
        # init the model with config
        self.build_model()
        # load model
        self.load_model()
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # ``-attr`` at files produced by a trusted preprocessing run.
        with open(self.args.attr, 'rb') as f:
            self.attr = pickle.load(f)

    def _device(self):
        """Return the device holding the model parameters.

        Falls back to CUDA when available (else CPU) for a parameterless
        model, so tensors are always sent where the model actually lives.
        """
        param = next(self.model.parameters(), None)
        if param is not None:
            return param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self):
        """Load the state dict stored at ``args.model`` into the model."""
        print(f'Load model from {self.args.model}')
        # map_location keeps a checkpoint saved on another device loadable.
        state = torch.load(self.args.model, map_location=self._device())
        self.model.load_state_dict(state)

    def build_model(self):
        """Instantiate ``AE`` from the config and switch it to eval mode."""
        # ``cc`` comes from the project's utils module (not visible here);
        # presumably it moves the module to the compute device -- confirm.
        self.model = cc(AE(self.config))
        print(self.model)
        self.model.eval()

    def utt_make_frames(self, x):
        """Group consecutive rows of a 2-D tensor into frames.

        Args:
            x: tensor of shape ``(length, features)``.

        Returns:
            Tensor of shape ``(1, frame_size * features, n_frames)`` where
            ``n_frames = ceil(length / frame_size)``. When ``length`` is not
            a multiple of ``frame_size`` the tail is zero-padded.
        """
        frame_size = self.config['data_loader']['frame_size']
        remains = x.size(0) % frame_size
        if remains != 0:
            # F.pad lists pads from the LAST dim backwards, so the first
            # pair (features) stays untouched and the second pair extends
            # dim 0 up to the next multiple of frame_size.
            x = F.pad(x, (0, 0, 0, frame_size - remains))
        n_frames = x.size(0) // frame_size
        # reshape (not view) so non-contiguous inputs are accepted too.
        out = x.reshape(1, n_frames, frame_size * x.size(1)).transpose(1, 2)
        return out

    def inference_one_utterance(self, x, x_cond):
        """Run the model on one (source, condition) pair.

        Args:
            x: normalized source features, shape ``(length, features)``.
            x_cond: normalized conditioning features, same layout.

        Returns:
            ``(wav_data, dec)``: the waveform produced by
            ``melspectrogram2wav`` and the denormalized model output as a
            NumPy array.
        """
        x = self.utt_make_frames(x)
        x_cond = self.utt_make_frames(x_cond)
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            dec = self.model.inference(x, x_cond)
        dec = dec.transpose(1, 2).squeeze(0)
        dec = dec.detach().cpu().numpy()
        dec = self.denormalize(dec)
        wav_data = melspectrogram2wav(dec)
        return wav_data, dec

    def denormalize(self, x):
        """Undo ``normalize``: return ``x * std + mean``."""
        m, s = self.attr['mean'], self.attr['std']
        ret = x * s + m
        return ret

    def normalize(self, x):
        """Return ``(x - mean) / std`` using the loaded statistics."""
        m, s = self.attr['mean'], self.attr['std']
        ret = (x - m) / s
        return ret

    def write_wav_to_file(self, wav_data, output_path):
        """Write ``wav_data`` to ``output_path`` at ``args.sample_rate``."""
        write(output_path, rate=self.args.sample_rate, data=wav_data)

    def inference_from_path(self):
        """Convert ``args.source`` conditioned on ``args.target``.

        Reads both files with ``get_spectrograms``, normalizes the first
        returned array of each, runs the model and writes the resulting
        waveform to ``args.output``.
        """
        src_mel, _ = get_spectrograms(self.args.source)
        tar_mel, _ = get_spectrograms(self.args.target)
        # Follow the model's device instead of assuming CUDA is present.
        device = self._device()
        src_mel = torch.from_numpy(self.normalize(src_mel)).to(device)
        tar_mel = torch.from_numpy(self.normalize(tar_mel)).to(device)
        conv_wav, _ = self.inference_one_utterance(src_mel, tar_mel)
        self.write_wav_to_file(conv_wav, self.args.output)
if __name__ == '__main__':
    # Command-line entry point: parse paths, load the YAML config, run one
    # conversion and write the result to the output path.
    parser = ArgumentParser()
    # Every path below is dereferenced unconditionally later on, so a
    # missing one is reported by argparse instead of failing mid-run.
    parser.add_argument('-attr', '-a', help='attr file path', required=True)
    parser.add_argument('-config', '-c', help='config file path', required=True)
    parser.add_argument('-model', '-m', help='model path', required=True)
    parser.add_argument('-source', '-s', help='source wav path', required=True)
    parser.add_argument('-target', '-t', help='target wav path', required=True)
    parser.add_argument('-output', '-o', help='output wav path', required=True)
    parser.add_argument('-sample_rate', '-sr', help='sample rate', default=24000, type=int)
    args = parser.parse_args()
    # load config file
    with open(args.config) as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6 and was unsafe before that; safe_load parses plain
        # YAML on every PyYAML version.
        config = yaml.safe_load(f)
    inferencer = Inferencer(config=config, args=args)
    inferencer.inference_from_path()