Skip to content

Commit

Permalink
rename MyDataset -> TTSDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Jun 28, 2021
1 parent 6c7bbca commit 42554cc
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions TTS/bin/compute_attention_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
Expand Down Expand Up @@ -83,7 +83,7 @@
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset(
dataset = TTSDataset(
model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
Expand All @@ -22,7 +22,7 @@


def setup_loader(ap, r, verbose=False):
dataset = MyDataset(
dataset = TTSDataset(
r,
c.text_cleaner,
compute_linear_spec=False,
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/train_align_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import AlignTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
Expand All @@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/train_glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
Expand All @@ -38,7 +38,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/train_speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
Expand All @@ -39,7 +39,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not config.run_eval:
loader = None
else:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=False,
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.utils.data import DataLoader

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
Expand Down Expand Up @@ -43,7 +43,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
loader = None
else:
if dataset is None:
dataset = MyDataset(
dataset = TTSDataset(
r,
config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron",
Expand Down
10 changes: 5 additions & 5 deletions TTS/tts/datasets/TTSDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence


class MyDataset(Dataset):
class TTSDataset(Dataset):
def __init__(
self,
outputs_per_step,
Expand Down Expand Up @@ -117,12 +117,12 @@ def _load_or_generate_phoneme_sequence(
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank
)
if enable_eos_bos:
Expand Down Expand Up @@ -190,7 +190,7 @@ def _phoneme_worker(args):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes

def compute_input_seq(self, num_workers=0):
Expand Down Expand Up @@ -225,7 +225,7 @@ def compute_input_seq(self, num_workers=0):
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
Expand Down
4 changes: 2 additions & 2 deletions notebooks/ExtractTTSpectrogram.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import MyDataset\n",
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n",
Expand Down Expand Up @@ -112,7 +112,7 @@
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
"dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion tests/data_tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs):

def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv")
dataset = TTSDataset.MyDataset(
dataset = TTSDataset.TTSDataset(
r,
c.text_cleaner,
compute_linear_spec=True,
Expand Down

0 comments on commit 42554cc

Please sign in to comment.