Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for model_info in CLI #1623

Merged
merged 19 commits into from
Jun 20, 2022
63 changes: 52 additions & 11 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def main():
$ tts --list_models
```

- Query info for model info by idx:

```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```

- Query info for model info by full name:

```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```

- Run TTS with default models:

```
Expand All @@ -48,7 +60,7 @@ def main():
- Run a TTS model with its default vocoder model:

```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
```

- Run with specific TTS and vocoder models from the list:
Expand Down Expand Up @@ -104,6 +116,21 @@ def main():
default=False,
help="list available pre-trained TTS and vocoder models.",
)

parser.add_argument(
"--model_info_by_idx",
type=str,
default=None,
help="model info using query format: <model_type>/<model_query_idx>",
)

parser.add_argument(
"--model_info_by_name",
type=str,
default=None,
help="model info using query format: <model_type>/<language>/<dataset>/<model_name>",
)

parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")

# Args for running pre-trained TTS models.
Expand Down Expand Up @@ -214,13 +241,16 @@ def main():
args = parser.parse_args()

# print the description if either text or list_models is not set
if (
not args.text
and not args.list_models
and not args.list_speaker_idxs
and not args.list_language_idxs
and not args.reference_wav
):
check_args = [
args.text,
args.list_models,
args.list_speaker_idxs,
args.list_language_idxs,
args.reference_wav,
args.model_info_by_idx,
args.model_info_by_name,
]
if not any(check_args):
parser.parse_args(["-h"])

# load model manager
Expand All @@ -236,20 +266,31 @@ def main():
encoder_path = None
encoder_config_path = None

# CASE1: list pre-trained TTS models
# CASE1 #list : list pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()

# CASE2: load pre-trained model paths
# CASE2 #info : model info of pre-trained TTS models
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
sys.exit()

if args.model_info_by_name:
model_query_full_name = args.model_info_by_name
manager.model_info_by_full_name(model_query_full_name)
sys.exit()

# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name

if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

# CASE3: set custom model paths
# CASE4: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
Expand Down
75 changes: 75 additions & 0 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,81 @@ def list_models(self):
models_name_list.extend(model_list)
return models_name_list

def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx

Args:
model_query (str): <model_tye>/<model_idx>
"""
model_name_list = []
model_type, model_query_idx = model_query.split("/")
try:
model_query_idx = int(model_query_idx)
if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!")
return
except:
print("> model_query_idx should be an integer!")
return
model_count = 0
if model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
else:
print(f"> model_type {model_type} does not exist in the list.")
return
if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ")
else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")

def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name

Args:
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
"""
model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict:
if lang in self.models_dict[model_type]:
if dataset in self.models_dict[model_type][lang]:
if model in self.models_dict[model_type][lang][dataset]:
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
)
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")

def list_tts_models(self):
"""Print all `TTS` models and return a list of model names

Expand Down