tts.py

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import sys

# Make the vendored TTS package importable from the current working directory.
PATH_TO_TTS_FOLDER = os.path.join(os.getcwd(), "TTS")
sys.path.append(PATH_TO_TTS_FOLDER)

import torch
from TTS.ForwardTacotron.models.forward_tacotron import ForwardTacotron
from TTS.ForwardTacotron.utils.marathi_grapheme.cleaners import MarathiCleaner
from TTS.ForwardTacotron.utils.marathi_grapheme.tokenizer import MarathiTokenizer
from TTS.ForwardTacotron.utils.hindi_grapheme.cleaners import HindiCleaner
from TTS.ForwardTacotron.utils.hindi_grapheme.tokenizer import HindiTokenizer
import TTS.cargan as cargan
import nvidia_smi
# Pick the first GPU with more than 10 GiB of free memory; fall back to CPU.
num_devices = torch.cuda.device_count()
device = torch.device('cpu')
nvidia_smi.nvmlInit()
for i in range(num_devices):
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    free_memory = info.free / 1073741824  # bytes -> GiB
    print("Free GPU memory (TTS): ", free_memory)
    if free_memory > 10:
        device = torch.device(f'cuda:{i}')
        break
    torch.cuda.empty_cache()
nvidia_smi.nvmlShutdown()
# TTS Paths
cargan_checkpoint_hi_male = "TTS/TTS/runs/hindi_male/best_netG.pt"
ft_checkpoint_hi_male = "TTS/TTS/ForwardTacotron/checkpoints/hindi_male/hi_grapheme.pt"
cargan_checkpoint_mr_female = "TTS/TTS/runs/marathi_female/best_netG.pt"
ft_checkpoint_mr_female = "TTS/TTS/ForwardTacotron/checkpoints/marathi_female/mr_grapheme.pt"
# Hi Male
## FT
checkpoint_hi_male = torch.load(ft_checkpoint_hi_male, map_location=torch.device('cpu'))
config_hi_male = checkpoint_hi_male['config']
tts_model_hi_male = ForwardTacotron.from_config(config_hi_male)
torch.cuda.empty_cache()
tts_model_hi_male.load_state_dict(checkpoint_hi_male['model'])
tts_model_hi_male.eval().to(device)
## Cargan
cargan_model_hi_male = cargan.load.model(cargan_checkpoint_hi_male).to(device)
# Mr Female
## FT
checkpoint_mr_female = torch.load(ft_checkpoint_mr_female, map_location=torch.device('cpu'))
config_mr_female = checkpoint_mr_female['config']
tts_model_mr_female = ForwardTacotron.from_config(config_mr_female)
tts_model_mr_female.load_state_dict(checkpoint_mr_female['model'])
tts_model_mr_female.eval().to(device)
## Cargan
cargan_model_mr_female = cargan.load.model(cargan_checkpoint_mr_female).to(device)
print("TTS models loaded on device ", device)
def Forwardtacotron(text, amp, lang):
    """Run the ForwardTacotron text-to-mel model for the given language."""
    if lang == 'mr':
        cleaner = MarathiCleaner.from_config(config_mr_female)
        tokenizer = MarathiTokenizer()
    elif lang == 'hi':
        cleaner = HindiCleaner.from_config(config_hi_male)
        tokenizer = HindiTokenizer()
    else:
        raise ValueError(f"Unsupported language: {lang}")
    pitch_function = lambda x: x * amp  # scale the predicted pitch contour by amp
    energy_function = lambda x: x
    x = cleaner(text)
    x = tokenizer(x)
    x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0)
    if lang == "hi":
        gen = tts_model_hi_male.generate(x=x,
                                         alpha=1.,
                                         pitch_function=pitch_function,
                                         energy_function=energy_function)
    else:  # lang == "mr"
        gen = tts_model_mr_female.generate(x=x,
                                           alpha=1.,
                                           pitch_function=pitch_function,
                                           energy_function=energy_function)
    # The Marathi checkpoint exposes its mel under 'mel'; the Hindi one under 'mel_post'.
    if lang == "mr":
        mel_output = gen['mel']
    else:
        mel_output = gen['mel_post']
    # Convert the natural-log mel spectrogram to log10, as expected by the vocoder.
    mel_output = torch.exp(mel_output)
    mel_output = torch.log10(mel_output)
    return mel_output
def Cargan(features, lang):
    """Vocode a mel spectrogram into a waveform with the CARGAN model."""
    if lang == "mr":
        with torch.no_grad():
            vocoded = cargan.ar_loop(cargan_model_mr_female, features)
    elif lang == "hi":
        with torch.no_grad():
            vocoded = cargan.ar_loop(cargan_model_hi_male, features)
    else:
        raise ValueError(f"Unsupported language: {lang}")
    return vocoded.squeeze(0)
def tts(text, lang, amp=1.0):
    """
    text (str): Text to be synthesized
    lang (str): Language of the input text ['hi', 'mr']
    amp (Optional, float): Pitch scaling factor for text-to-mel

    Returns the synthesized speech as a waveform tensor on the CPU.
    """
    # Mel spectrogram generated by the text-to-mel model
    mel_output = Forwardtacotron(text, amp, lang)
    # Vocode the mel spectrogram into the final speech waveform
    vocoded = Cargan(mel_output, lang).cpu()
    return vocoded
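
# Minimal usage sketch: synthesize one Hindi sentence and write it to disk.
# The soundfile dependency and the 22050 Hz sample rate are assumptions, not
# taken from this file; check the checkpoint config for the actual rate.
if __name__ == "__main__":
    import soundfile as sf

    wav = tts("नमस्ते, आप कैसे हैं?", lang="hi", amp=1.0)
    sf.write("output.wav", wav.squeeze().numpy(), samplerate=22050)  # assumed rate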