VALL-E: new version release (#220)
A refined version of VALL-E. We have changed the underlying implementation to Llama to provide better model performance, faster training speed, and more readable code.
jiaqili3 authored Jun 21, 2024
1 parent d335514 commit f96a153
Showing 31 changed files with 7,092 additions and 28 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -25,14 +25,15 @@
- **TTM**: Text to Music (👨‍💻 developing)
- more…

In addition to the specific generation tasks, Amphion also includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks.
In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks.

Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are generated by our models. Just enjoy it!

[amphion-v0.1-en](https://github.com/open-mmlab/Amphion/assets/24860155/7fcdcea5-3d95-4b31-bd93-4b4da734ef9b)

## 🚀 News
- **2024/6/17**: Amphion has a new release of its VALL-E model; it uses Llama as its underlying architecture and offers better model performance, faster training speed, and more readable code than our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
- **2024/03/12**: Amphion now supports **NaturalSpeech3 FACodec** and releases pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
- **2024/02/22**: The first Amphion visualization tool, **SingVisio**, is released. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://github.com/open-mmlab/Amphion/assets/33707885/0a6e39e8-d5f1-4288-b0f8-32da5a2d6e96) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md)
- **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39)
@@ -42,10 +43,10 @@ Here is the Amphion v0.1 demo, whose voice, audio effects, and singing voice are

### TTS: Text to Speech

- Amphion achieves state-of-the-art performance when compared with existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures:
- Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures:
- [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks.
- [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.
- [Vall-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
- [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
- [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices.

### SVC: Singing Voice Conversion
@@ -139,6 +140,7 @@ We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTIN

- [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code.
- [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design.
- [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design.
- [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code.
- [HiFi-GAN](https://github.com/jik876/hifi-gan) for GAN-based Vocoder's architecture design and training strategy.
- [Encodec](https://github.com/facebookresearch/encodec) for well-organized GAN Discriminator's architecture and basic blocks.
87 changes: 63 additions & 24 deletions bins/tts/train.py
@@ -11,6 +11,9 @@
from models.tts.vits.vits_trainer import VITSTrainer
from models.tts.valle.valle_trainer import VALLETrainer
from models.tts.naturalspeech2.ns2_trainer import NS2Trainer
from models.tts.VALLE_V2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR
from models.tts.VALLE_V2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR

from utils.util import load_config


@@ -20,6 +23,8 @@ def build_trainer(args, cfg):
"VITS": VITSTrainer,
"VALLE": VALLETrainer,
"NaturalSpeech2": NS2Trainer,
"VALLE_V2_AR": VALLE_V2_AR,
"VALLE_V2_NAR": VALLE_V2_NAR,
}

trainer_class = supported_trainer[cfg.model_type]
@@ -32,6 +37,7 @@ def cuda_relevant(deterministic=False):
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
@@ -47,6 +53,13 @@ def main():
help="json files for configurations.",
required=True,
)
parser.add_argument(
"--seed",
type=int,
default=1234,
help="random seed",
required=False,
)
parser.add_argument(
"--exp_name",
type=str,
@@ -57,6 +70,9 @@
parser.add_argument(
"--resume", action="store_true", help="The model name to restore"
)
parser.add_argument(
"--test", action="store_true", default=False, help="Test the model"
)
parser.add_argument(
"--log_level", default="warning", help="logging level (debug, info, warning)"
)
@@ -72,39 +88,62 @@
default=None,
help="Checkpoint for resume training or finetuning.",
)

VALLETrainer.add_arguments(parser)
parser.add_argument(
"--resume_from_ckpt_path",
type=str,
default="",
help="Checkpoint for resume training or finetuning.",
)
# VALLETrainer.add_arguments(parser)
args = parser.parse_args()
cfg = load_config(args.config)

# Data Augmentation
if (
type(cfg.preprocess.data_augment) == list
and len(cfg.preprocess.data_augment) > 0
):
new_datasets_list = []
for dataset in cfg.preprocess.data_augment:
new_datasets = [
f"{dataset}_pitch_shift" if cfg.preprocess.use_pitch_shift else None,
(
f"{dataset}_formant_shift"
if cfg.preprocess.use_formant_shift
else None
),
f"{dataset}_equalizer" if cfg.preprocess.use_equalizer else None,
f"{dataset}_time_stretch" if cfg.preprocess.use_time_stretch else None,
]
new_datasets_list.extend(filter(None, new_datasets))
cfg.dataset.extend(new_datasets_list)

if hasattr(cfg, "preprocess"):
if hasattr(cfg.preprocess, "data_augment"):
if (
type(cfg.preprocess.data_augment) == list
and len(cfg.preprocess.data_augment) > 0
):
new_datasets_list = []
for dataset in cfg.preprocess.data_augment:
new_datasets = [
(
f"{dataset}_pitch_shift"
if cfg.preprocess.use_pitch_shift
else None
),
(
f"{dataset}_formant_shift"
if cfg.preprocess.use_formant_shift
else None
),
(
f"{dataset}_equalizer"
if cfg.preprocess.use_equalizer
else None
),
(
f"{dataset}_time_stretch"
if cfg.preprocess.use_time_stretch
else None
),
]
new_datasets_list.extend(filter(None, new_datasets))
cfg.dataset.extend(new_datasets_list)

print("experiment name: ", args.exp_name)
# # CUDA settings
cuda_relevant()

# Build trainer
print(f"Building {cfg.model_type} trainer")
trainer = build_trainer(args, cfg)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
trainer.train_loop()
print(f"Start training {cfg.model_type} model")
if args.test:
trainer.test_loop()
else:
trainer.train_loop()


if __name__ == "__main__":
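
For reference, here is a hypothetical invocation of the updated entry point, which the VALLE_V2 recipe scripts presumably wrap; the config path and experiment name below are placeholders, while `--seed` and `--test` are the arguments added in this commit:

```bash
# Train with the new --seed argument:
python bins/tts/train.py \
    --config egs/tts/VALLE_V2/exp_ar_libritts.json \
    --exp_name my_valle_ar \
    --seed 1234

# Run the evaluation branch instead of training via the new --test flag:
python bins/tts/train.py \
    --config egs/tts/VALLE_V2/exp_ar_libritts.json \
    --exp_name my_valle_ar \
    --test
```
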
169 changes: 169 additions & 0 deletions egs/tts/valle_v2/README.md
@@ -0,0 +1,169 @@
# VALL-E
## Introduction
This is an unofficial PyTorch implementation of VALL-E, a zero-shot voice cloning model via neural codec language modeling ([paper link](https://arxiv.org/abs/2301.02111)).
If trained properly, this model could match the performance specified in the original paper.

## Change notes
This is a refined version of the first VALL-E release in Amphion: we have changed the underlying implementation to Llama
to provide better model performance, faster training speed, and more readable code.
This can be a great tool if you want to learn about speech language models and their implementation.

## Installation requirements

Set up your environment as in the Amphion README (you'll need a conda environment, and we recommend using Linux). A GPU is recommended if you want to train this model yourself.
For inference with our pretrained models, you can generate samples even without a GPU.
To ensure your transformers library can run the code, we recommend additionally running:
```bash
pip install -U transformers==4.41.2
```

<!-- espeak-ng is required to run G2p. To install it, you could refer to:
https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md
For Linux, it should be `sudo apt-get install espeak-ng`.
For Windows, refer to the above link.
If you do not have sudo privilege, you could build the library by following the last section of this readme. -->

## Running inference with pretrained VALL-E models
### Download pretrained weights
You need to download our pretrained weights from Hugging Face.

Script to download the AR and NAR model checkpoints:
```bash
huggingface-cli download amphion/valle valle_ar_mls_196000.bin valle_nar_mls_164000.bin --local-dir ckpts
```
Script to download codec model (SpeechTokenizer) checkpoint:
```bash
huggingface-cli download amphion/valle speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts
```
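
Assuming `huggingface-cli` keeps the repository-relative paths under `--local-dir` (its default behavior), the downloaded files should land roughly like this:

```
ckpts
├── valle_ar_mls_196000.bin
├── valle_nar_mls_164000.bin
└── speechtokenizer_hubert_avg
    ├── SpeechTokenizer.pt
    └── config.json
```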

### Inference in IPython notebook

We provide our pretrained VALL-E model, trained on the 45k-hour MLS dataset.
The "demo.ipynb" file provides a working example of running inference with our pretrained VALL-E model. Give it a try!

## Examining the model files
Examining the model files of VALL-E is a great way to learn how it works.
We provide examples that allow you to overfit a single batch (so no dataset download is required).

The AR model is essentially a causal language model that "continues" a speech prompt. The NAR model is a modification of the AR model that allows for bidirectional attention, as the sketch below illustrates.
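
The following is a toy PyTorch sketch (not the repository's code) of the attention patterns behind this distinction: the AR model restricts each position to attend only to earlier positions, while the NAR model can attend to the whole sequence.

```python
import torch

seq_len = 6

# AR-style causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# NAR-style bidirectional attention: every position may attend to every position.
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

print(causal_mask.int())         # lower-triangular pattern
print(bidirectional_mask.int())  # all ones
```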


The files `valle_ar.py` and `valle_nar.py` in the "models/tts/VALLE_V2" folder are the model files; they can be run directly via `python -m models.tts.VALLE_V2.valle_ar` (or `python -m models.tts.VALLE_V2.valle_nar`).
This invokes a test that overfits the model to a single example.

## Training VALL-E from scratch
### Preparing LibriTTS or LibriTTS-R dataset files

We have tested our training script on LibriTTS and LibriTTS-R.
You could download LibriTTS-R at [this link](https://www.openslr.org/141/) and LibriTTS at [this link](https://www.openslr.org/60).
The "train-clean-360" split is currently used by our configuration.
You can test the dataset script by running `python -m models.tts.VALLE_V2.libritts_dataset`.

For your reference, our unzipped dataset files have a file structure like this:
```
/path/to/LibriTTS_R
├── BOOKS.txt
├── CHAPTERS.txt
├── dev-clean
│ ├── 2412
│ │ ├── 153947
│ │ │ ├── 2412_153947_000014_000000.normalized.txt
│ │ │ ├── 2412_153947_000014_000000.original.txt
│ │ │ ├── 2412_153947_000014_000000.wav
│ │ │ ├── 2412_153947_000017_000001.normalized.txt
│ │ │ ├── 2412_153947_000017_000001.original.txt
│ │ │ ├── 2412_153947_000017_000001.wav
│ │ │ ├── 2412_153947_000017_000005.normalized.txt
├── train-clean-360
├── 422
│ │ └── 122949
│ │ ├── 422_122949_000009_000007.normalized.txt
│ │ ├── 422_122949_000009_000007.original.txt
│ │ ├── 422_122949_000009_000007.wav
│ │ ├── 422_122949_000013_000010.normalized.txt
│ │ ├── 422_122949_000013_000010.original.txt
│ │ ├── 422_122949_000013_000010.wav
│ │ ├── 422_122949.book.tsv
│ │ └── 422_122949.trans.tsv
```


Alternatively, you could write your own dataloader for your dataset.
You can reference the `__getitem__` method in `models/tts/VALLE_V2/mls_dataset.py`.
It should return a dict with a 1-dimensional tensor 'speech', which is a 16kHz speech waveform, and a 1-dimensional tensor 'phone', which is the phoneme sequence of the speech.
As long as your dataset returns this from `__getitem__`, it should work; see the sketch below.
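
To make the expected format concrete, here is a minimal sketch of such a dataset. It assumes you have already prepared (wav path, phoneme-ID list) pairs yourself; the class and variable names are illustrative, not the repository's implementation.

```python
import torch
import torchaudio
from torch.utils.data import Dataset


class MyValleDataset(Dataset):
    """Hypothetical custom dataset returning the dict described above."""

    def __init__(self, items):
        # items: list of (wav_path, phone_ids) pairs from your own preprocessing
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        wav_path, phone_ids = self.items[idx]
        speech, sr = torchaudio.load(wav_path)    # (channels, num_samples)
        speech = speech.mean(dim=0)               # mono, 1-dimensional
        if sr != 16000:
            speech = torchaudio.functional.resample(speech, sr, 16000)
        return {
            "speech": speech,                                    # 1-D float tensor, 16 kHz
            "phone": torch.tensor(phone_ids, dtype=torch.long),  # 1-D phoneme IDs
        }
```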

### Changing batch size and dataset path in the configuration file
Our configuration file for training the VALL-E AR model is at "egs/tts/VALLE_V2/exp_ar_libritts.json", and for the NAR model at "egs/tts/VALLE_V2/exp_nar_libritts.json".

To train your model, you need to modify the `dataset` variable in the JSON configurations.
Currently it is at line 40; you should set "data_dir" to your dataset's root directory.
```
"dataset": {
"dataset_list":["train-clean-360"], // You can also change to other splits like "dev-clean"
"data_dir": "/path/to/your/LibriTTS_R",
},
```

You should also select a reasonable batch size at the "batch_size" entry (currently it's set at 5).


You can change other experiment settings in `egs/tts/VALLE_V2/exp_ar_libritts.json`, such as the learning rate, optimizer, and dataset.

Here we choose the `libritts` dataset we added and set `use_dynamic_dataset` to false.

The `use_dynamic_dataset` config addresses inconsistent sequence lengths and improves GPU utilization; here we set it to false for simplicity.

```json
"dataset": {
"use_dynamic_batchsize": false,
"name": "libritts"
},
```

We also recommend changing "num_hidden_layers" if your GPU memory is limited.

**Set a smaller batch_size if you are out of memory 😢😢**

We used batch_size=3 to run successfully on a single GPU; if you are out of memory, try a smaller value.

```json
"batch_size": 3,
"max_tokens": 11000,
"max_sentences": 64,
"random_seed": 0
```


### Run the command to train the AR model
(Make sure your current directory is at the Amphion root directory).
Run:
```sh
sh egs/tts/VALLE_V2/train_ar_libritts.sh
```
Your model checkpoint can be found at `ckpt/VALLE_V2/ar_libritts/checkpoint/epoch-0000_step-0000000_loss-7.397293/pytorch_model.bin`.


### Resume from existing checkpoint
Our framework supports resuming from an existing checkpoint.

Run:
```sh
sh egs/tts/VALLE_V2/train_ar_libritts.sh --resume
```

### Run the command to train the NAR model
(Make sure your current directory is at the Amphion root directory).
Run:
```sh
sh egs/tts/VALLE_V2/train_nar_libritts.sh
```

### Run inference with your models
Since our inference script is already provided, you can change the paths
from our pretrained models to your newly trained models and run inference.

## Future plans
- [ ] Support more languages
- [ ] More are coming...