From bbea6a08840598e0a4e655007b2e9577b132f501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 29 Jan 2021 15:00:33 +0000 Subject: [PATCH] hubconf.py and load .models.json from the defualt location by mange.py --- TTS/hubconf.py | 26 ++++++++++++++++++++++++++ TTS/utils/manage.py | 16 +++++++++++----- 2 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 TTS/hubconf.py diff --git a/TTS/hubconf.py b/TTS/hubconf.py new file mode 100644 index 0000000000..c4e5bc99d2 --- /dev/null +++ b/TTS/hubconf.py @@ -0,0 +1,26 @@ +dependencies = ['torch', 'gdown'] +import torch +import os +import zipfile + +from TTS.utils.generic_utils import get_user_data_dir +from TTS.utils.synthesizer import Synthesizer +from TTS.utils.manage import ModelManager + + + +def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', pretrained=True): + manager = ModelManager() + + model_path, config_path = manager.download_model(model_name) + vocoder_path, vocoder_config_path = manager.download_model(vocoder_name) + + # create synthesizer + synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path) + return synthesizer + + +if __name__ == '__main__': + # synthesizer = torch.hub.load('/data/rw/home/projects/TTS/TTS', 'tts', source='local') + synthesizer = torch.hub.load('mozilla/TTS:hub_conf', 'tts', source='github') + synthesizer.tts("This is a test!") \ No newline at end of file diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index af741156a8..3cf8d67f72 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,10 +1,11 @@ import json -import gdown -from pathlib import Path import os +from pathlib import Path -from TTS.utils.io import load_config +import gdown from TTS.utils.generic_utils import get_user_data_dir +from TTS.utils.io import load_config + class ModelManager(object): """Manage TTS models defined in .models.json. @@ -17,12 +18,17 @@ class ModelManager(object): Args: models_file (str): path to .model.json """ - def __init__(self, models_file): + def __init__(self, models_file=None): super().__init__() self.output_prefix = get_user_data_dir('tts') self.url_prefix = "https://drive.google.com/uc?id=" self.models_dict = None - self.read_models_file(models_file) + if models_file is not None: + self.read_models_file(models_file) + else: + # try the default location + path = Path(__file__).parent / "../.models.json" + self.read_models_file(path) def read_models_file(self, file_path): """Read .models.json as a dict