Skip to content

Commit

Permalink
Fix the bug in split dataset function (coqui-ai#1251)
Browse files Browse the repository at this point in the history
* Fix the bug in split_dataset

* Make eval_split_size configurable

* Change test_loader to use load_tts_samples function

* Change eval_split_portion to eval_split_size and permits to set the absolute number of samples in eval

* Fix samplers unit test

* Add data unit test on GitHub workflow
  • Loading branch information
Edresson authored Feb 21, 2022
1 parent a19021d commit 28a7464
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 24 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/data_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: data-tests

on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
experimental: [false]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: coqui-ai/setup-python@pip-cache-key-py-ver
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make data_tests
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ test_aux: ## run aux tests.
test_zoo: ## run zoo tests.
nosetests tests.zoo_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.zoo_tests --nologcapture --with-id

data_tests: ## run data tests.
nosetests tests.data_tests -x --with-cov -cov --cover-erase --cover-package TTS tests.data_tests --nologcapture --with-id

test_failed: ## only run tests failed the last time.
nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --failed

Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def main(args): # pylint: disable=redefined-outer-name
ap = AudioProcessor(**c.audio)

# load data instances
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval)
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size)

# use eval and training partitions
meta_data = meta_data_train + meta_data_eval
Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/find_unique_chars.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def main():
c = load_config(args.config_path)

# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size)

items = train_items + eval_items

Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/find_unique_phonemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main():
c = load_config(args.config_path)

# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size)
items = train_items + eval_items
print("Num items:", len(items))

Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/train_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
config = register_config(config_base.model)()

# load training samples
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size)

# setup audio processor
ap = AudioProcessor(**config.audio)
Expand Down
10 changes: 10 additions & 0 deletions TTS/tts/configs/shared_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ class BaseTTSConfig(BaseTrainingConfig):
test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]'
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
"""

audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
Expand Down Expand Up @@ -218,3 +225,6 @@ class BaseTTSConfig(BaseTrainingConfig):
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
41 changes: 32 additions & 9 deletions TTS/tts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,40 @@
from TTS.tts.datasets.formatters import *


def split_dataset(items):
def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
items (List[List]): A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
items (List[List]):
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
"""
speakers = [item[-1] for item in items]
speakers = [item["speaker_name"] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = min(500, int(len(items) * 0.01))
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
if eval_split_size > 1:
eval_split_size = int(eval_split_size)
else:
if eval_split_max_size:
eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size))
else:
eval_split_size = int(len(items) * eval_split_size)

assert eval_split_size > 0, " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(1/len(items))
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
speakers = [item[-1] for item in items]
speakers = [item["speaker_name"] for item in items]
speaker_counter = Counter(speakers)
while len(items_eval) < eval_split_size:
item_idx = np.random.randint(0, len(items))
speaker_to_be_removed = items[item_idx][-1]
speaker_to_be_removed = items[item_idx]["speaker_name"]
if speaker_counter[speaker_to_be_removed] > 1:
items_eval.append(items[item_idx])
speaker_counter[speaker_to_be_removed] -= 1
Expand All @@ -37,7 +52,8 @@ def split_dataset(items):


def load_tts_samples(
datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None
datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None,
eval_split_max_size=None, eval_split_size=0.01
) -> Tuple[List[List], List[List]]:
"""Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
Expand All @@ -55,6 +71,13 @@ def load_tts_samples(
`[[audio_path, text, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as
example. Defaults to None.
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
Returns:
Tuple[List[List], List[List]: training and evaluation splits of the dataset.
"""
Expand Down Expand Up @@ -84,7 +107,7 @@ def load_tts_samples(
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for the duration predictor training
Expand Down
6 changes: 5 additions & 1 deletion TTS/tts/datasets/formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,15 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r", encoding="utf-8") as ttf:
speaker_id = 0
for idx, line in enumerate(ttf):
# 2 samples per speaker to avoid eval split issues
if idx%2 == 0:
speaker_id += 1
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"})
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
return items


Expand Down
27 changes: 19 additions & 8 deletions tests/data_tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from tests import get_tests_output_path
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.datasets import TTSDataset
from TTS.tts.datasets.formatters import ljspeech
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.utils.audio import AudioProcessor

# pylint: disable=unused-variable
Expand All @@ -18,11 +18,19 @@
os.makedirs(OUTPATH, exist_ok=True)

# create a dummy config for testing data loaders.
c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2)
c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2, use_noise_augment=False)
c.r = 5
c.data_path = "tests/data/ljspeech/"
ok_ljspeech = os.path.exists(c.data_path)

dataset_config = BaseDatasetConfig(
name="ljspeech_test", # ljspeech_test to multi-speaker
meta_file_train="metadata.csv",
meta_file_val=None,
path=c.data_path,
language="en",
)

DATA_EXIST = True
if not os.path.exists(c.data_path):
DATA_EXIST = False
Expand All @@ -37,11 +45,10 @@ def __init__(self, *args, **kwargs):
self.ap = AudioProcessor(**c.audio)

def _create_dataloader(self, batch_size, r, bgs):
items = ljspeech(c.data_path, "metadata.csv")

# add a default language because now the TTSDataset expect a language
language = ""
items = [[*item, language] for item in items]
# load dataset
meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
items = meta_data_train + meta_data_eval

dataset = TTSDataset(
r,
Expand Down Expand Up @@ -97,8 +104,12 @@ def test_loader(self):

# make sure that the computed mels and the waveform match and correctly computed
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
# remove padding in mel-spectrogram
mel_dataloader = mel_input[0].T.numpy()[:, :mel_lengths[0]]
# guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
mel_new = mel_new[:, :mel_lengths[0]]
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg]
mel_diff = (mel_new - mel_dataloader)[:, 0:ignore_seg]
assert abs(mel_diff.sum()) < 1e-5

# check normalization ranges
Expand Down
4 changes: 2 additions & 2 deletions tests/data_tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def is_balanced(lang_1, lang_2):
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == "en":
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
Expand All @@ -50,7 +50,7 @@ def is_balanced(lang_1, lang_2):
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index][3] == "en":
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
Expand Down

0 comments on commit 28a7464

Please sign in to comment.