This repository has been archived by the owner on Mar 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tacotron2.py
103 lines (85 loc) · 3.42 KB
/
tacotron2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import warnings
from functools import partial
from typing import Any, Optional
from ._api import register, Weights, WeightEntry
from ..transforms.audio_presets import Text2Characters
# Import a few stuff that we plan to keep as-is to avoid copy-pasting
from torchaudio.models.tacotron2 import Tacotron2
# Public star-import surface of this module: only the Tacotron2 model class that
# is re-exported from torchaudio above.
# NOTE(review): the `tacotron2` builder and `Tacotron2Weights` defined below are not
# listed here — confirm whether that is intentional.
__all__ = ["Tacotron2"]
# Baseline keyword arguments for torchaudio's ``Tacotron2`` constructor.
# Caller-supplied keyword arguments are merged on top of a copy of this mapping,
# so the module-level dict itself is never mutated. ``n_symbol`` is not listed
# here; it is handled separately by the builder.
_DEFAULT_PARAMETERS = dict(
    # general
    mask_padding=False,
    n_mels=80,
    n_frames_per_step=1,
    # symbol embedding / encoder_* settings
    symbol_embedding_dim=512,
    encoder_embedding_dim=512,
    encoder_n_convolution=3,
    encoder_kernel_size=5,
    # decoder_* settings
    decoder_rnn_dim=1024,
    decoder_max_step=2000,
    decoder_dropout=0.1,
    decoder_early_stopping=True,
    # attention_* settings
    attention_rnn_dim=1024,
    attention_hidden_dim=128,
    attention_location_n_filter=32,
    attention_location_kernel_size=31,
    attention_dropout=0.1,
    # prenet_* / postnet_* settings
    prenet_dim=256,
    postnet_n_convolution=5,
    postnet_kernel_size=5,
    postnet_embedding_dim=512,
    # misc
    gate_threshold=0.5,
)
class Tacotron2Weights(Weights):
    """Pre-trained Tacotron2 checkpoints trained on LJSpeech.

    Each member wraps a ``WeightEntry`` built from four positional arguments:
    the checkpoint URL, a transforms factory (or ``None``), a metadata dict and
    a trailing boolean flag.

    ``meta['n_symbol']`` is copied into the model config as ``n_symbol`` by the
    ``tacotron2`` builder, so it must agree with the vocabulary of the checkpoint.

    NOTE(review): the meaning of the trailing ``True`` is defined by ``WeightEntry``
    in ``._api`` (not visible here) — confirm before relying on it.
    NOTE(review): the builder iterates over this class, which suggests ``Weights`` is
    Enum-based; adding plain class attributes here may create new members — confirm.
    """
    # Character-based input. n_symbol 38 = 12 punctuation/space symbols + 26
    # lowercase letters in the symbol string passed to Text2Characters.
    Characters_LJSpeech = WeightEntry(
        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_ljspeech.pth',
        partial(Text2Characters, symbols="_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"),
        {'lang': 'en', 'epochs': 1500, 'n_symbol': 38},
        True
    )
    # Same character vocabulary as above; the URL indicates a WaveRNN variant.
    Characters_WaveRNN_LJSpeech = WeightEntry(
        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
        partial(Text2Characters, symbols="_-!'(),.:;? abcdefghijklmnopqrstuvwxyz"),
        {'lang': 'en', 'epochs': 1500, 'n_symbol': 38},
        True
    )
    # Phoneme-based input (96 symbols). No text transform is wired up yet
    # (transforms slot is None), so callers must tokenize input themselves.
    Phonemes_LJSpeech = WeightEntry(
        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_ljspeech.pth',
        None,  # Phonemes preprocessing goes here
        {'lang': 'en', 'epochs': 1500, 'n_symbol': 96},
        True
    )
    # Phoneme-based WaveRNN variant; transform likewise still missing.
    Phonemes_WaveRNN_LJSpeech = WeightEntry(
        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
        None,  # Phonemes preprocessing goes here
        {'lang': 'en', 'epochs': 1500, 'n_symbol': 96},
        True
    )
@register
def tacotron2(n_symbol: Optional[int] = None, weights: Optional[Tacotron2Weights] = None, progress: bool = False,
              **kwargs: Any) -> Tacotron2:
    """Build a torchaudio ``Tacotron2`` model, optionally loading pre-trained weights.

    Args:
        n_symbol: Size of the input symbol vocabulary. Used when ``weights`` is
            ``None``. When ``weights`` is given, the checkpoint's
            ``meta['n_symbol']`` takes precedence (a warning is emitted if the
            two disagree).
        weights: Pre-trained checkpoint to load, or ``None`` for a freshly
            initialized model.
        progress: Forwarded to ``weights.state_dict``.
        **kwargs: Overrides for entries of ``_DEFAULT_PARAMETERS``, forwarded to
            the ``Tacotron2`` constructor. The deprecated ``checkpoint_name`` key
            is also accepted: it is the file name (without extension) of one of
            the ``Tacotron2Weights`` URLs, and replaces ``weights``.

    Returns:
        The constructed (and, if requested, weight-loaded) model.

    Raises:
        ValueError: If ``checkpoint_name`` matches no known checkpoint, or if
            both ``n_symbol`` and ``weights`` are ``None``.
    """

    def _checkpoint_name(entry: Tacotron2Weights) -> str:
        # File name of the checkpoint URL with its extension (".pth") removed.
        return os.path.splitext(os.path.basename(entry.url))[0]

    # Backward compatibility for checkpoint_name
    if "checkpoint_name" in kwargs:
        warnings.warn("The argument checkpoint_name is deprecated, please use weights instead.")
        checkpoint_name = kwargs.pop("checkpoint_name")
        weights = next((x for x in Tacotron2Weights if _checkpoint_name(x) == checkpoint_name), None)
        if weights is None:
            valid = ", ".join(_checkpoint_name(x) for x in Tacotron2Weights)
            raise ValueError(f"Unexpected checkpoint_name: '{checkpoint_name}'. Expected one of: {valid}.")

    # Confirm we got the right weights
    Tacotron2Weights.check_type(weights)

    if n_symbol is None and weights is None:
        raise ValueError("Both n_symbol and weights can't be None.")

    # Build parameters by overwriting the defaults with caller-supplied kwargs
    config = {
        **_DEFAULT_PARAMETERS,
        **kwargs,
    }

    # n_symbol is a named parameter, so it never appears in kwargs and must be
    # copied into config explicitly; otherwise an explicit value would be lost.
    if weights is not None:
        # The checkpoint dictates the vocabulary size it was trained with.
        checkpoint_n_symbol = weights.meta['n_symbol']
        if n_symbol is not None and n_symbol != checkpoint_n_symbol:
            warnings.warn(
                f"n_symbol={n_symbol} is ignored; the selected weights require n_symbol={checkpoint_n_symbol}."
            )
        config['n_symbol'] = checkpoint_n_symbol
    else:
        config['n_symbol'] = n_symbol

    # Initialize model
    model = Tacotron2(**config)

    # Optionally load weights
    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model