Fix VALLE inference bug (#253)
jiaqili3 authored Jul 26, 2024
1 parent 72112a6 commit a17f139
Showing 2 changed files with 20 additions and 39 deletions.
egs/tts/VALLE_V2/README.md (55 changes: 18 additions & 37 deletions)
@@ -17,13 +17,6 @@
To ensure your transformers library can run the code, we recommend additionally running:
```bash
pip install -U transformers==4.41.2
```
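To confirm the pinned version is the one your environment actually imports, a quick sanity check (a minimal sketch, not part of the original recipe):
```python
import transformers

# The recipe above pins transformers 4.41.2.
assert transformers.__version__ == "4.41.2", transformers.__version__
```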

<!-- espeak-ng is required to run G2p. To install it, you could refer to:
https://github.com/espeak-ng/espeak-ng/blob/master/docs/guide.md
For Linux, it should be `sudo apt-get install espeak-ng`.
For Windows, refer to the above link.
If you do not have sudo privilege, you could build the library by following the last section of this readme. -->

## Inferencing pretrained VALL-E models
### Download pretrained weights
You need to download our pretrained weights from Hugging Face.

@@ -34,12 +27,21 @@
Script to download the AR and NAR model checkpoints:
```bash
huggingface-cli download amphion/valle valle_ar_mls_196000.bin valle_nar_mls_164000.bin --local-dir ckpts
```
Script to download codec model (SpeechTokenizer) checkpoint:
```bash
-huggingface-cli download amphion/valle speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts
+mkdir -p ckpts/speechtokenizer_hubert_avg && huggingface-cli download amphion/valle SpeechTokenizer.pt config.json --local-dir ckpts/speechtokenizer_hubert_avg
```

If you cannot access Hugging Face, consider using the hf-mirror.com mirror to download:
```bash
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download amphion/valle valle_ar_mls_196000.bin valle_nar_mls_164000.bin --local-dir ckpts
```
```bash
mkdir -p ckpts/speechtokenizer_hubert_avg && HF_ENDPOINT=https://hf-mirror.com huggingface-cli download amphion/valle SpeechTokenizer.pt config.json --local-dir ckpts/speechtokenizer_hubert_avg
```
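If you prefer a scripted download, here is a minimal Python sketch using `huggingface_hub` (an alternative to the CLI commands above; assumes the `huggingface_hub` package is installed, with the repo ID and filenames taken from the commands above):
```python
from huggingface_hub import hf_hub_download

# To route through the mirror, set HF_ENDPOINT=https://hf-mirror.com
# in your environment before running this script.

# AR and NAR model checkpoints
for name in ["valle_ar_mls_196000.bin", "valle_nar_mls_164000.bin"]:
    hf_hub_download(repo_id="amphion/valle", filename=name, local_dir="ckpts")

# Codec (SpeechTokenizer) checkpoint and config
for name in ["SpeechTokenizer.pt", "config.json"]:
    hf_hub_download(
        repo_id="amphion/valle",
        filename=name,
        local_dir="ckpts/speechtokenizer_hubert_avg",
    )
```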


### Inference in IPython notebook

-We provide our pretrained VALL-E model that is trained on 45k hours MLS dataset.
+We provide our pretrained VALL-E model, trained on the 45k-hour MLS dataset, which contains 10-20s English speech clips.
The "demo.ipynb" file provides a working example of inferencing our pretrained VALL-E model. Give it a try!

## Examining the model files
@@ -49,7 +51,7 @@
We provide examples that allow you to overfit a single batch (so no dataset download is required).
The AR model is essentially a causal language model that "continues" a speech prompt. The NAR model is a modification of the AR model that allows bidirectional attention.
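As a toy illustration of that difference (a minimal sketch, not the Amphion implementation; dimensions are arbitrary), the AR model restricts attention with a causal mask while the NAR model leaves attention unmasked:
```python
import torch

T = 6  # toy sequence length
x = torch.randn(1, T, 8)  # (batch, time, embed_dim)
attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

# AR: True entries are masked out, so position t attends only to positions <= t.
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
ar_out, _ = attn(x, x, x, attn_mask=causal_mask)

# NAR: no mask, so every position attends to every other (bidirectional).
nar_out, _ = attn(x, x, x)
```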


-File `valle_ar.py` and `valle_nar.py` in "models/tts/VALLE_V2" folder are models files, these files can be run directly via `python -m models.tts.VALLE_V2.valle_ar` (or `python -m models.tts.VALLE_V2.valle_nar`).
+The files `valle_ar.py` and `valle_nar.py` in the "models/tts/valle_v2" folder are model files; they can be run directly via `python -m models.tts.valle_v2.valle_ar` (or `python -m models.tts.valle_v2.valle_nar`).
This will invoke a test which overfits it to a single example.

## Training VALL-E from scratch
@@ -58,7 +60,7 @@
We have tested our training script on LibriTTS and LibriTTS-R.
You could download LibriTTS-R at [this link](https://www.openslr.org/141/) and LibriTTS at [this link](https://www.openslr.org/60).
The "train-clean-360" split is currently used by our configuration.
-You can test dataset.py by run `python -m models.tts.VALLE_V2.libritts_dataset`.
+You can test dataset.py by running `python -m models.tts.valle_v2.libritts_dataset`.

For your reference, our unzipped dataset files have a file structure like this:
```
...
```

@@ -111,38 +113,13 @@
You should also select a reasonable batch size at the "batch_size" entry (currently …)

You can change other experiment settings in `/egs/tts/VALLE_V2/exp_ar_libritts.json`, such as the learning rate, the optimizer, and the dataset.

Here we choose the `libritts` dataset we added and set `use_dynamic_batchsize` to false.

The `use_dynamic_batchsize` option addresses inconsistent sequence lengths and improves GPU utilization; here we set it to false for simplicity.

```json
"dataset": {
"use_dynamic_batchsize": false,
"name": "libritts"
},
```

We also recommend changing "num_hidden_layers" if your GPU memory is limited.

**Set a smaller batch_size if you are out of memory 😢😢**

I used batch_size=3 to run successfully on a single card; if you're out of memory, try a smaller value.

```json
"batch_size": 3,
"max_tokens": 11000,
"max_sentences": 64,
"random_seed": 0
```


### Run the command to train the AR model
(Make sure your current directory is at the Amphion root directory).
Run:
```sh
sh egs/tts/VALLE_V2/train_ar_libritts.sh
```
-Your model checkpoint could be found in `ckpt/VALLE_V2/ar_libritts/checkpoint/epoch-0000_step-0000000_loss-7.397293/pytorch_model.bin`
+Your initial model checkpoint can be found in places such as `ckpt/VALLE_V2/ar_libritts/checkpoint/epoch-0000_step-0000000_loss-7.397293/pytorch_model.bin`.
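To sanity-check what a saved checkpoint contains, here is a generic PyTorch sketch (assuming the file is a plain `state_dict`; the exact epoch/step/loss path will differ from run to run):
```python
import torch

# Load on CPU and list a few parameter names and shapes; the path is an example.
state = torch.load(
    "ckpt/VALLE_V2/ar_libritts/checkpoint/epoch-0000_step-0000000_loss-7.397293/pytorch_model.bin",
    map_location="cpu",
)
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))
```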


### Resume from existing checkpoint
@@ -153,6 +130,10 @@
Run:
sh egs/tts/VALLE_V2/train_ar_libritts.sh --resume
```

### Finetuning based on our AR model
We provide our AR model's optimizer and random_states checkpoints to support finetuning (no need to download these files if you are only running inference with the pretrained model). First rename the files to "pytorch_model.bin", "optimizer.bin", and "random_states_0.pkl"; then you can resume from these checkpoints. [Link to the AR optimizer checkpoint](https://huggingface.co/amphion/valle/blob/main/optimizer_valle_ar_mls_196000.bin) and [link to random_states_0.pkl](https://huggingface.co/amphion/valle/blob/main/random_states_0.pkl).
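A minimal sketch of that download-and-rename step (the target directory below is an assumption; point it wherever your resume logic expects checkpoints):
```python
import os
import shutil
from huggingface_hub import hf_hub_download

# Assumed resume directory; adjust to your experiment's checkpoint layout.
ckpt_dir = "ckpt/VALLE_V2/ar_libritts/checkpoint/finetune_start"
os.makedirs(ckpt_dir, exist_ok=True)

# Download each file, then copy it to the name the trainer expects.
for src, dst in [
    ("valle_ar_mls_196000.bin", "pytorch_model.bin"),
    ("optimizer_valle_ar_mls_196000.bin", "optimizer.bin"),
    ("random_states_0.pkl", "random_states_0.pkl"),
]:
    path = hf_hub_download(repo_id="amphion/valle", filename=src)
    shutil.copy(path, os.path.join(ckpt_dir, dst))
```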


### Run the command to train the NAR model
(Make sure your current directory is at the Amphion root directory).
Run:
egs/tts/VALLE_V2/demo.ipynb (4 changes: 2 additions & 2 deletions)
@@ -78,7 +78,7 @@
"# prepare inference data\n",
"import librosa\n",
"import torch\n",
"wav, _ = librosa.load('./egs/tts/valle_v2/example.wav', sr=16000)\n",
"wav, _ = librosa.load('./egs/tts/VALLE_V2/example.wav', sr=16000)\n",
"wav = torch.tensor(wav, dtype=torch.float32)\n",
"from IPython.display import Audio\n",
"Audio(wav, rate = 16000)"
@@ -235,7 +235,7 @@
"outputs": [],
"source": [
"import torchaudio\n",
"torchaudio.save('out.wav', output_wav.squeeze(0), 24000)"
"torchaudio.save('out.wav', output_wav.squeeze(0), 16000)"
]
}
],
