Training & Inference code for FAcodec (#229)
* Training & Inference code for FAcodec
* Update vocoder_trainer.py
* Added copyright statements & code source (where necessary)
* reformatted files with black formatter
* reformat
* reformat reformat
* reformat reformat reformat
Showing 41 changed files with 7,442 additions and 2 deletions.
@@ -0,0 +1,99 @@

```python
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_inference import FAcodecInference
from utils.util import load_config


def build_inference(args, cfg):
    supported_inference = {
        "FAcodec": FAcodecInference,
    }

    inference_class = supported_inference[cfg.model_type]
    inference = inference_class(args, cfg)
    return inference


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def build_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Acoustic model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint.",
    )
    parser.add_argument(
        "--source",
        type=str,
        required=True,
        help="Path to the source audio file.",
    )
    parser.add_argument(
        "--reference",
        type=str,
        default=None,
        help="Path to the reference audio file; if provided, zero-shot "
        "voice conversion is run in addition to reconstruction.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results.",
    )
    return parser


def main():
    # Parse arguments
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    # Parse config
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    # Build inference
    inferencer = build_inference(args, cfg)

    # Run inference
    _ = inferencer.inference(args.source, args.output_dir)

    # Run voice conversion
    if args.reference is not None:
        _ = inferencer.voice_conversion(args.source, args.reference, args.output_dir)


if __name__ == "__main__":
    main()
```
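Since the script exposes `build_parser`, `build_inference`, and `cuda_relevant` as plain functions, the same flow can be driven without a shell. The sketch below is a minimal example, not part of the released API: the module path `bins.codec.inference` assumes the repository root is the working directory (the README further down invokes the script by file path), and every path argument is a placeholder.

```python
# A minimal programmatic sketch of the entry point above. The module path
# assumes you run from the repository root; all paths are placeholders.
from bins.codec.inference import build_inference, build_parser, cuda_relevant
from utils.util import load_config

parser = build_parser()
args = parser.parse_args(
    [
        "--config", "egs/codec/FAcodec/exp_custom_data.json",
        "--checkpoint_path", "/path/to/checkpoint",
        "--source", "/path/to/source.wav",
        "--output_dir", "./outputs",
    ]
)

cfg = load_config(args.config)
cuda_relevant()
inferencer = build_inference(args, cfg)

# Per the README below, the return value is a (quantized, codes) tuple.
quantized, codes = inferencer.inference(args.source, args.output_dir)
```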
@@ -0,0 +1,79 @@

```python
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_trainer import FAcodecTrainer
from utils.util import load_config


def build_trainer(args, cfg):
    supported_trainer = {
        "FAcodec": FAcodecTrainer,
    }

    trainer_class = supported_trainer[cfg.model_type]
    trainer = trainer_class(args, cfg)
    return trainer


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        help="JSON file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        help="'resume' to continue training, 'finetune' for finetuning",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="checkpoint to resume from",
    )
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    args = parser.parse_args()
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    # Build trainer
    trainer = build_trainer(args, cfg)

    trainer.train_loop()


if __name__ == "__main__":
    main()
```
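The training entry point can be launched the same way. A minimal sketch under stated assumptions: the module path `bins.codec.train` is hypothetical (mirroring the inference script; the recipe below goes through `./egs/codec/FAcodec/train.sh` instead), and the experiment name and paths are placeholders.

```python
# A minimal sketch of launching training programmatically. The module path
# bins.codec.train is an assumption; the released recipe invokes it via
# egs/codec/FAcodec/train.sh. All paths are placeholders.
from argparse import Namespace

from bins.codec.train import build_trainer, cuda_relevant
from utils.util import load_config

# This Namespace mirrors the CLI flags parsed in main() above.
args = Namespace(
    config="egs/codec/FAcodec/exp_custom_data.json",
    exp_name="facodec_demo",  # hypothetical experiment name
    resume_type=None,         # "resume" to continue training, "finetune" to finetune
    checkpoint=None,          # checkpoint path when resuming
    log_level="warning",
)

cfg = load_config(args.config)
cuda_relevant()
trainer = build_trainer(args, cfg)
trainer.train_loop()
```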
@@ -0,0 +1,67 @@

```json
{
    "exp_name": "facodec",
    "model_type": "FAcodec",
    "log_dir": "./runs/",
    "log_interval": 10,
    "save_interval": 1000,
    "device": "cuda",
    "epochs": 1000,
    "batch_size": 4,
    "batch_length": 100,
    "max_len": 80,
    "pretrained_model": "",
    "load_only_params": false,
    "F0_path": "modules/JDC/bst.t7",
    "dataset": "dummy",
    "preprocess_params": {
        "sr": 24000,
        "frame_rate": 80,
        "duration_range": [1.0, 25.0],
        "spect_params": {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
            "n_mels": 80
        }
    },
    "train": {
        "gradient_accumulation_step": 1,
        "batch_size": 1,
        "save_checkpoint_stride": [20],
        "random_seed": 1234,
        "max_epoch": -1,
        "max_frame_len": 80,
        "tracker": ["tensorboard"],
        "run_eval": [false],
        "sampler": {"holistic_shuffle": true, "drop_last": true},
        "dataloader": {"num_worker": 0, "pin_memory": true}
    },
    "model_params": {
        "causal": true,
        "lstm": 2,
        "norm_f0": true,
        "use_gr_content_f0": false,
        "use_gr_prosody_phone": false,
        "use_gr_timbre_prosody": false,
        "separate_prosody_encoder": true,
        "n_c_codebooks": 2,
        "timbre_norm": true,
        "use_gr_content_global_f0": true,
        "DAC": {
            "encoder_dim": 64,
            "encoder_rates": [2, 5, 5, 6],
            "decoder_dim": 1536,
            "decoder_rates": [6, 5, 5, 2],
            "sr": 24000
        }
    },
    "loss_params": {
        "base_lr": 0.0001,
        "warmup_steps": 200,
        "discriminator_iter_start": 2000,
        "lambda_spk": 1.0,
        "lambda_mel": 45,
        "lambda_f0": 1.0,
        "lambda_uv": 1.0
    }
}
```
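Before launching a run, it can be worth checking that the derived quantities in a config like the one above agree with each other: the mel frame rate should equal `sr / hop_length` (24000 / 300 = 80), and the DAC encoder strides should multiply to exactly one code frame per hop (2 * 5 * 5 * 6 = 300). A minimal sketch, assuming the file is saved under the hypothetical name `facodec.json`:

```python
import json
from math import prod

# "facodec.json" is a hypothetical filename; point this at your config.
with open("facodec.json") as f:
    cfg = json.load(f)

pp = cfg["preprocess_params"]

# Declared frame rate vs. the one implied by the STFT: 24000 // 300 == 80.
assert pp["sr"] // pp["spect_params"]["hop_length"] == pp["frame_rate"]

# The DAC encoder downsamples by the product of its strides:
# 2 * 5 * 5 * 6 == 300 samples, i.e. exactly one code frame per hop.
rates = cfg["model_params"]["DAC"]["encoder_rates"]
assert prod(rates) == pp["spect_params"]["hop_length"]

print("config OK:", cfg["model_type"])
```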
@@ -0,0 +1,51 @@

# FAcodec

PyTorch implementation for the training of FAcodec, which was proposed in the paper [NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models](https://arxiv.org/pdf/2403.03100).

A dedicated repository for the FAcodec model can also be found [here](https://github.com/Plachtaa/FAcodec).

This implementation makes some key improvements to the training pipeline, eliminating the need for any form of annotation, including transcripts, phoneme alignments, and speaker labels; all you need is raw speech files. With the new training pipeline, it is possible to train the model on more languages with more diverse timbre distributions. We release the code for training and inference, including a checkpoint pretrained on 50k hours of speech data with over 1 million speakers.

## Model storage
We provide pretrained checkpoints trained on 50k hours of speech data.

| Model type | Link |
|------------|------|
| FAcodec | [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-FAcodec-blue)](https://huggingface.co/Plachta/FAcodec) |

## Demo
Try our model on [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/Plachta/FAcodecV2)!

## Training
Prepare your data and put it under one folder; the internal file structure does not matter. Then, change the `dataset` field in `./egs/codec/FAcodec/exp_custom_data.json` to the path of your data folder. Finally, run the following command:
```bash
sh ./egs/codec/FAcodec/train.sh
```

## Inference
To reconstruct a speech file, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```
To use zero-shot voice conversion, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --reference <reference_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```

## Feature extraction
When running `./bins/codec/inference.py`, check the results returned by the `FAcodecInference` class: a tuple of `(quantized, codes)`.
- `quantized` is the quantized representation of the input speech file.
  - `quantized[0]` is the quantized representation of prosody.
  - `quantized[1]` is the quantized representation of content.
- `codes` is the discrete code representation of the input speech file.
  - `codes[0]` is the discrete code representation of prosody.
  - `codes[1]` is the discrete code representation of content.

For the cleanest content representation without any timbre, we suggest using `codes[1][:, 0, :]`, which is the first layer of the content codebooks.
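The snippet below sketches this slicing, assuming `quantized` and `codes` were obtained from the inference entry script shown earlier; the `(batch, n_codebooks, frames)` layout is inferred from the indexing above, not verified against the implementation.

```python
# Assuming: quantized, codes = inferencer.inference(source, output_dir)
prosody_repr, content_repr = quantized[0], quantized[1]
prosody_codes, content_codes = codes[0], codes[1]

# First layer of the content codebooks: per the note above, the cleanest
# content representation, with timbre factored out. The assumed layout is
# (batch, n_codebooks, frames).
clean_content = content_codes[:, 0, :]
print(clean_content.shape)
```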
@@ -0,0 +1,80 @@

```json
{
    "exp_name": "facodec",
    "model_type": "FAcodec",

    "log_dir": "./runs/",
    "log_interval": 10,
    "save_interval": 1000,
    "device": "cuda",
    "epochs": 1000,
    "batch_size": 4,
    "batch_length": 100,
    "max_len": 80,
    "pretrained_model": "",
    "load_only_params": false,
    "F0_path": "modules/JDC/bst.t7",
    "dataset": "/path/to/dataset",
    "preprocess_params": {
        "sr": 24000,
        "frame_rate": 80,
        "duration_range": [1.0, 25.0],
        "spect_params": {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
            "n_mels": 80
        }
    },
    "train": {
        "gradient_accumulation_step": 1,
        "batch_size": 1,
        "save_checkpoint_stride": [20],
        "random_seed": 1234,
        "max_epoch": -1,
        "max_frame_len": 80,
        "tracker": ["tensorboard"],
        "run_eval": [false],
        "sampler": {
            "holistic_shuffle": true,
            "drop_last": true
        },
        "dataloader": {
            "num_worker": 0,
            "pin_memory": true
        }
    },
    "model_params": {
        "causal": true,
        "lstm": 2,
        "norm_f0": true,
        "use_gr_content_f0": false,
        "use_gr_prosody_phone": false,
        "use_gr_timbre_prosody": false,
        "separate_prosody_encoder": true,
        "n_c_codebooks": 2,
        "timbre_norm": true,
        "use_gr_content_global_f0": true,
        "DAC": {
            "encoder_dim": 64,
            "encoder_rates": [2, 5, 5, 6],
            "decoder_dim": 1536,
            "decoder_rates": [6, 5, 5, 2],
            "sr": 24000
        }
    },
    "loss_params": {
        "base_lr": 0.0001,
        "warmup_steps": 200,
        "discriminator_iter_start": 2000,
        "lambda_spk": 1.0,
        "lambda_mel": 45,
        "lambda_f0": 1.0,
        "lambda_uv": 1.0
    }
}
```