Skip to content

Commit

Permalink
update compute_attention_masks.py
Browse files Browse the repository at this point in the history
Rate limit · GitHub

Access has been restricted

You have triggered a rate limit.

Please wait a few minutes before you try again;
in some cases this may take up to an hour.

root committed Jan 13, 2021
1 parent 0a9767a commit 7beaacc
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions TTS/bin/compute_attention_masks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
Sample run on LJSpeech dataset.
>>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
--model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
--config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
--dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
--data_path /home/erogol/Data/LJSpeech-1.1/ \
--batch_size 16 \
--use_cuda true
"""


import argparse
import importlib
import os
@@ -20,6 +6,7 @@
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from argparse import RawTextHelpFormatter
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint
@@ -30,40 +17,52 @@

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Extract attention masks from trained Tacotron models.')
description='''Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''

'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''

'''
Example run:
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
--data_path /root/LJSpeech-1.1/
--batch_size 32
--dataset ljspeech
--use_cuda True
''',
formatter_class=RawTextHelpFormatter
)
parser.add_argument('--model_path',
type=str,
help='Path to Tacotron or Tacotron2 model file ')
required=True,
help='Path to Tacotron/Tacotron2 model file ')
parser.add_argument(
'--config_path',
type=str,
required=True,
help='Path to config file for training.',
help='Path to Tacotron/Tacotron2 config file.',
)
parser.add_argument('--dataset',
type=str,
default='',
help='Dataset from TTS.tts.dataset.preprocess.')
required=True,
help='Target dataset processor name from TTS.tts.dataset.preprocess.')

parser.add_argument(
'--dataset_metafile',
type=str,
default='',
required=True,
help='Dataset metafile inclusing file paths with transcripts.')
parser.add_argument(
'--data_path',
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')

parser.add_argument('--use_cuda',
type=bool,
default=False,
@@ -148,10 +147,8 @@
mode='nearest',
align_corners=None,
recompute_scale_factor=None).squeeze(0).transpose(0, 1)

# remove paddings
alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()

# set file paths
wav_file_name = os.path.basename(item_idx)
align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
@@ -160,7 +157,7 @@
file_paths.append([item_idx, file_path])
np.save(file_path, alignment)

# ourpur metafile
# ourput metafile
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

with open(metafile, "w") as f:

0 comments on commit 7beaacc

Please sign in to comment.