Training & Inference code for FAcodec (#229)
* Training & Inference code for FAcodec

* Update vocoder_trainer.py

* Added copyright statements & code source (where necessary)

* reformatted files with black formatter

* reformat

* reformat reformat

* reformat reformat reformat
Plachtaa authored Aug 4, 2024
1 parent a17f139 commit 211e1d4
Showing 41 changed files with 7,442 additions and 2 deletions.
99 changes: 99 additions & 0 deletions bins/codec/inference.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_inference import FAcodecInference
from utils.util import load_config


def build_inference(args, cfg):
supported_inference = {
"FAcodec": FAcodecInference,
}

inference_class = supported_inference[cfg.model_type]
inference = inference_class(args, cfg)
return inference


def cuda_relevant(deterministic=False):
torch.cuda.empty_cache()
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
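    # Note: when deterministic=True, some CUDA ops additionally require the
    # CUBLAS_WORKSPACE_CONFIG environment variable (e.g. ":4096:8") to be set.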
torch.use_deterministic_algorithms(deterministic)


def build_parser():
parser = argparse.ArgumentParser()

parser.add_argument(
"--config",
type=str,
required=True,
help="JSON/YAML file for configurations.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Acoustic model checkpoint directory. If a directory is given, "
"search for the latest checkpoint dir in the directory. If a specific "
"checkpoint dir is given, directly load the checkpoint.",
)
parser.add_argument(
"--source",
type=str,
required=True,
help="Path to the source audio file",
)
parser.add_argument(
"--reference",
type=str,
default=None,
help="Path to the reference audio file, passing an",
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Output dir for saving generated results",
)
return parser


def main():
# Parse arguments
parser = build_parser()
args = parser.parse_args()
print(args)

# Parse config
cfg = load_config(args.config)

# CUDA settings
cuda_relevant()

# Build inference
inferencer = build_inference(args, cfg)

# Run inference
_ = inferencer.inference(args.source, args.output_dir)

# Run voice conversion
if args.reference is not None:
_ = inferencer.voice_conversion(args.source, args.reference, args.output_dir)


if __name__ == "__main__":
main()
79 changes: 79 additions & 0 deletions bins/codec/train.py
@@ -0,0 +1,79 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_trainer import FAcodecTrainer

from utils.util import load_config


def build_trainer(args, cfg):
supported_trainer = {
"FAcodec": FAcodecTrainer,
}

trainer_class = supported_trainer[cfg.model_type]
trainer = trainer_class(args, cfg)
return trainer


def cuda_relevant(deterministic=False):
torch.cuda.empty_cache()
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
torch.use_deterministic_algorithms(deterministic)


def main():
parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="A unique name to identify the experiment.",
    )
parser.add_argument(
"--resume_type",
type=str,
help="resume for continue to train, finetune for finetuning",
)
parser.add_argument(
"--checkpoint",
type=str,
help="checkpoint to resume",
)
parser.add_argument(
"--log_level", default="warning", help="logging level (debug, info, warning)"
)
args = parser.parse_args()
cfg = load_config(args.config)

# CUDA settings
cuda_relevant()

# Build trainer
trainer = build_trainer(args, cfg)

trainer.train_loop()


if __name__ == "__main__":
main()
67 changes: 67 additions & 0 deletions config/facodec.json
@@ -0,0 +1,67 @@
{
"exp_name": "facodec",
"model_type": "FAcodec",
"log_dir": "./runs/",
"log_interval": 10,
"save_interval": 1000,
"device": "cuda",
"epochs": 1000,
"batch_size": 4,
"batch_length": 100,
"max_len": 80,
"pretrained_model": "",
"load_only_params": false,
"F0_path": "modules/JDC/bst.t7",
"dataset": "dummy",
"preprocess_params": {
"sr": 24000,
"frame_rate": 80,
"duration_range": [1.0, 25.0],
"spect_params": {
"n_fft": 2048,
"win_length": 1200,
"hop_length": 300,
"n_mels": 80,
},
},
"train": {
"gradient_accumulation_step": 1,
"batch_size": 1,
"save_checkpoint_stride": [20],
"random_seed": 1234,
"max_epoch": -1,
"max_frame_len": 80,
"tracker": ["tensorboard"],
"run_eval": [false],
"sampler": {"holistic_shuffle": true, "drop_last": true},
"dataloader": {"num_worker": 0, "pin_memory": true},
},
"model_params": {
"causal": true,
"lstm": 2,
"norm_f0": true,
"use_gr_content_f0": false,
"use_gr_prosody_phone": false,
"use_gr_timbre_prosody": false,
"separate_prosody_encoder": true,
"n_c_codebooks": 2,
"timbre_norm": true,
"use_gr_content_global_f0": true,
"DAC": {
"encoder_dim": 64,
"encoder_rates": [2, 5, 5, 6],
"decoder_dim": 1536,
"decoder_rates": [6, 5, 5, 2],
"sr": 24000,
},
},
"loss_params": {
"base_lr": 0.0001,
"warmup_steps": 200,
"discriminator_iter_start": 2000,
"lambda_spk": 1.0,
"lambda_mel": 45,
"lambda_f0": 1.0,
"lambda_uv": 1.0,
},
}
51 changes: 51 additions & 0 deletions egs/codec/FAcodec/README.md
@@ -0,0 +1,51 @@
# FAcodec

PyTorch implementation for training FAcodec, as proposed in the paper [NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models](https://arxiv.org/pdf/2403.03100).

A dedicated repository for the FAcodec model can also be found [here](https://github.com/Plachtaa/FAcodec).

This implementation makes several key improvements to the training pipeline, eliminating the need for any form of annotation, including transcripts, phoneme alignments, and speaker labels: raw speech files are all you need.
The new pipeline makes it possible to train the model on more languages with more diverse timbre distributions.
We release code for training and inference, along with a checkpoint pretrained on 50k hours of speech data covering over 1 million speakers.

## Model storage
We provide a checkpoint pretrained on 50k hours of speech data.

| Model type | Link |
|-------------------|----------------------------------------------------------------------------------------------------------------------------------------|
| FAcodec | [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-FAcodec-blue)](https://huggingface.co/Plachta/FAcodec) |

## Demo
Try our model on [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/Plachta/FAcodecV2)!

## Training
Prepare your data and place it under one folder; the internal file structure does not matter.
Then, set `dataset` in `./egs/codec/FAcodec/exp_custom_data.json` to the path of your data folder.
Finally, run the following command:
```bash
sh ./egs/codec/FAcodec/train.sh
```

## Inference
To reconstruct a speech file, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```
To use zero-shot voice conversion, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --reference <reference_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```

## Feature extraction
When running `./bins/codec/inference.py`, inspect the values returned by `FAcodecInference`: a tuple of `(quantized, codes)`
- `quantized` is the quantized representation of the input speech file.
- `quantized[0]` is the quantized representation of prosody
- `quantized[1]` is the quantized representation of content

- `codes` is the discrete code representation of the input speech file.
- `codes[0]` is the discrete code representation of prosody
- `codes[1]` is the discrete code representation of content

For the cleanest content representation, free of any timbre information, we suggest using `codes[1][:, 0, :]`, which is the first layer of the content codebooks.
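
Below is a minimal sketch of extracting these content codes programmatically. It assumes the `inference()` method returns the `(quantized, codes)` tuple described above; the config and checkpoint paths are placeholders for your own files.

```python
# Minimal sketch: extract timbre-free content codes from a speech file.
# Assumes inference() returns the (quantized, codes) tuple described above;
# the paths below are placeholders, not files shipped with the repo.
import argparse

from models.codec.facodec.facodec_inference import FAcodecInference
from utils.util import load_config

args = argparse.Namespace(
    config="config/facodec.json",     # model configuration
    checkpoint_path="ckpts/facodec",  # pretrained checkpoint (directory or file)
    source="example.wav",             # speech file to encode
    output_dir="outputs",
)
cfg = load_config(args.config)
inferencer = FAcodecInference(args, cfg)

quantized, codes = inferencer.inference(args.source, args.output_dir)
content_codes = codes[1][:, 0, :]  # first layer of the content codebooks
print(content_codes.shape)
```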
80 changes: 80 additions & 0 deletions egs/codec/FAcodec/exp_custom_data.json
@@ -0,0 +1,80 @@
{
"exp_name": "facodec",
"model_type": "FAcodec",

"log_dir": "./runs/",
"log_interval": 10,
"save_interval": 1000,
"device": "cuda",
"epochs": 1000,
"batch_size": 4,
"batch_length": 100,
"max_len": 80,
"pretrained_model": "",
"load_only_params": false,
"F0_path": "modules/JDC/bst.t7",
"dataset": "/path/to/dataset",
"preprocess_params": {
"sr": 24000,
"frame_rate": 80,
"duration_range": [1.0, 25.0],
"spect_params": {
"n_fft": 2048,
"win_length": 1200,
"hop_length": 300,
"n_mels": 80
}
},
"train": {
"gradient_accumulation_step": 1,
"batch_size": 1,
"save_checkpoint_stride": [
20
],
"random_seed": 1234,
"max_epoch": -1,
"max_frame_len": 80,
"tracker": [
"tensorboard"
],
"run_eval": [
false
],
"sampler": {
"holistic_shuffle": true,
"drop_last": true
},
"dataloader": {
"num_worker": 0,
"pin_memory": true
}
},
"model_params": {
"causal": true,
"lstm": 2,
"norm_f0": true,
"use_gr_content_f0": false,
"use_gr_prosody_phone": false,
"use_gr_timbre_prosody": false,
"separate_prosody_encoder": true,
"n_c_codebooks": 2,
"timbre_norm": true,
"use_gr_content_global_f0": true,
"DAC": {
"encoder_dim": 64,
"encoder_rates": [2, 5, 5, 6],
"decoder_dim": 1536,
"decoder_rates": [6, 5, 5, 2],
"sr": 24000
}
},
"loss_params": {
"base_lr": 0.0001,
"warmup_steps": 200,
"discriminator_iter_start": 2000,
"lambda_spk": 1.0,
"lambda_mel": 45,
"lambda_f0": 1.0,
"lambda_uv": 1.0
}
}