Skip to content

Commit

Permalink
Update notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Oct 21, 2021
1 parent 5e0d053 commit 016803b
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 69 deletions.
120 changes: 60 additions & 60 deletions notebooks/ExtractTTSpectrogram.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand All @@ -20,7 +22,7 @@
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n",
Expand All @@ -33,13 +35,13 @@
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n",
Expand All @@ -51,13 +53,13 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
Expand All @@ -77,13 +79,13 @@
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
Expand All @@ -95,13 +97,13 @@
"# TODO: multiple speaker\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
Expand All @@ -120,20 +122,20 @@
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate model outputs "
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
Expand Down Expand Up @@ -212,89 +214,89 @@
"\n",
" print(np.mean(losses))\n",
" print(np.mean(postnet_losses))"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Check"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot decoder output\n",
"print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot GT specgrogram\n",
"print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
Expand All @@ -303,13 +305,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
Expand All @@ -318,13 +320,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
Expand All @@ -334,22 +336,23 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [],
"metadata": {},
"outputs": [],
"metadata": {}
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.7 64-bit ('base': conda)"
"display_name": "Python 3.9.7 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -362,11 +365,8 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
7 changes: 2 additions & 5 deletions notebooks/PlotUmapLibriTTS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
"source": [
"import os\n",
"import glob\n",
"import random\n",
"import numpy as np\n",
"import torch\n",
"import umap\n",
"\n",
"from TTS.speaker_encoder.model import SpeakerEncoder\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.generic_utils import load_config\n",
"from TTS.config import load_config\n",
"\n",
"from bokeh.io import output_notebook, show\n",
"from bokeh.plotting import figure\n",
"from bokeh.models import HoverTool, ColumnDataSource, BoxZoomTool, ResetTool, OpenURL, TapTool\n",
"from bokeh.transform import factor_cmap, factor_mark\n",
"from bokeh.transform import factor_cmap\n",
"from bokeh.palettes import Category10"
]
},
Expand Down
1 change: 0 additions & 1 deletion notebooks/dataset_analysis/AnalyzeDataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"import os\n",
"import sys\n",
"sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n",
"import glob\n",
"import librosa\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down
5 changes: 2 additions & 3 deletions notebooks/dataset_analysis/CheckDatasetSNR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"import os\n",
"import glob\n",
"import subprocess\n",
"import tempfile\n",
"import IPython\n",
"import soundfile as sf\n",
"import numpy as np\n",
Expand Down Expand Up @@ -208,4 +207,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
Loading

0 comments on commit 016803b

Please sign in to comment.