Skip to content

Commit

Permalink
platform indep. way to fetch user data folder
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jan 26, 2021
1 parent 0117c81 commit 4f32e77
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
19 changes: 19 additions & 0 deletions TTS/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import shutil
import subprocess
import sys
from pathlib import Path

import torch

Expand Down Expand Up @@ -67,6 +69,22 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def get_user_data_dir(appname):
if sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == 'darwin':
ans = Path('~/Library/Application Support/').expanduser()
else:
ans = Path.home().joinpath('.local/share')
return ans.joinpath(appname)


def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
Expand Down Expand Up @@ -97,6 +115,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
len(model_dict)))
return model_dict


class KeepAverage():
def __init__(self):
self.avg_values = {}
Expand Down
4 changes: 2 additions & 2 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os

from TTS.utils.io import load_config

from TTS.utils.generic_utils import get_user_data_dir

class ModelManager(object):
"""Manage TTS models defined in .models.json.
Expand All @@ -19,7 +19,7 @@ class ModelManager(object):
"""
def __init__(self, models_file):
super().__init__()
self.output_prefix = os.path.join(str(Path.home()), '.tts')
self.output_prefix = get_user_data_dir('tts')
self.url_prefix = "https://drive.google.com/uc?id="
self.models_dict = None
self.read_models_file(models_file)
Expand Down

0 comments on commit 4f32e77

Please sign in to comment.