Add notes and citations.
lucasnewman committed Oct 13, 2024
1 parent a3c3d50 commit 551f575
Showing 3 changed files with 124 additions and 29 deletions.
78 changes: 77 additions & 1 deletion README.md
@@ -1 +1,77 @@
# f5-tts-mlx
![F5 TTS diagram](f5tts.jpg)

# F5 TTS — MLX

An implementation of [F5-TTS](https://arxiv.org/abs/2410.06885) using the [MLX](https://github.com/ml-explore/mlx) framework.

F5 TTS is a non-autoregressive, zero-shot text-to-speech system using a flow-matching mel spectrogram generator with a diffusion transformer (DiT).

F5 TTS is an evolution of [E2 TTS](https://arxiv.org/abs/2406.18009v2) that improves performance by using ConvNeXt V2 blocks for the learned text alignment.

This repository is based on the original PyTorch implementation available [here](https://github.com/SWivid/F5-TTS).


## Installation

```bash
pip install f5-tts-mlx
```

Pretrained model weights are available [on Hugging Face](https://huggingface.co/SWivid/F5-TTS).

## Usage

```python
import mlx.core as mx

# Hyphens are not valid in Python module names; the package imports as f5_tts_mlx.
from f5_tts_mlx.cfm import CFM
from f5_tts_mlx.dit import DiT

vocab = ...
f5tts = CFM(
    transformer = DiT(
        dim = 1024,
        depth = 22,
        heads = 16,
        ff_mult = 2,
        text_dim = 512,
        conv_layers = 4,
        text_num_embeds = ...
    ),
    vocab_char_map=vocab
)
mx.eval(f5tts.parameters())
```

See `test_infer_single.py` for an example of generation.
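
The `vocab` placeholder above is a character-to-index map. As a minimal sketch, the helper below wraps the same comprehension that `test_infer_single.py` uses to build it from a vocabulary file. The function name `load_vocab`, the explicit UTF-8 encoding, and the demo file contents are assumptions made for illustration and are not part of this repository.

```python
import tempfile
from pathlib import Path


def load_vocab(vocab_path: Path) -> dict:
    """Map each line of a vocabulary file to its zero-based line index."""
    # Same comprehension as test_infer_single.py; the encoding is an assumption.
    lines = vocab_path.read_text(encoding="utf-8").split("\n")
    return {token: index for index, token in enumerate(lines)}


# Demo with a throwaway file (hypothetical contents, not the real vocabulary).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "vocab.txt"
    demo_path.write_text("a\nb\nc\n", encoding="utf-8")
    demo_vocab = load_vocab(demo_path)

print(demo_vocab)  # {'a': 0, 'b': 1, 'c': 2, '': 3}
```

Because the file is split on `"\n"`, a file that ends with a newline produces one extra entry keyed by the empty string, as the demo output shows.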

## Appreciation

[Yushen Chen](https://github.com/SWivid) for the original PyTorch implementation of F5 TTS and pretrained model.

[Phil Wang](https://github.com/lucidrains) for the E2 TTS implementation that this model is based on.

## Citations

```bibtex
@article{chen-etal-2024-f5tts,
title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
journal={arXiv preprint arXiv:2410.06885},
year={2024},
}
```

```bibtex
@inproceedings{Eskimez2024E2TE,
title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS},
author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:270738197}
}
```

## License

The code in this repository is released under the MIT license as found in the
[LICENSE](LICENSE) file.
Binary file added f5tts.jpg
75 changes: 47 additions & 28 deletions test_infer_single.py
@@ -17,47 +17,58 @@
 vocab_path = Path("data/Emilia_ZH_EN_pinyin/vocab.txt")
 vocab = {v: i for i, v in enumerate(vocab_path.read_text().split("\n"))}

-f5tts = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4, text_num_embeds = len(vocab) - 1)
-cfm = CFM(transformer=f5tts, vocab_char_map=vocab)
-
-ckpt_path = Path("model_1200000.pt")
-state_dict = torch.load(ckpt_path, map_location='cpu', weights_only=True)['ema_model_state_dict']
+f5tts = CFM(
+    transformer=DiT(
+        dim=1024,
+        depth=22,
+        heads=16,
+        ff_mult=2,
+        text_dim=512,
+        conv_layers=4,
+        text_num_embeds=len(vocab) - 1,
+    ),
+    vocab_char_map=vocab,
+)
+
+ckpt_path = Path("./ckpts/F5TTS_Base/model_1200000.pt")
+
+state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["ema_model_state_dict"]

 # load weights

 new_state_dict = {}
 for k, v in state_dict.items():
-    k = k.replace('ema_model.', '')
+    k = k.replace("ema_model.", "")
     v = mx.array(v.numpy())

     # rename layers
-    if len(k) < 1 or 'mel_spec.' in k or k in ('initted', 'step'):
+    if len(k) < 1 or "mel_spec." in k or k in ("initted", "step"):
         continue
-    elif '.to_out' in k:
-        k = k.replace('.to_out', '.to_out.layers')
-    elif '.text_blocks' in k:
-        k = k.replace('.text_blocks', '.text_blocks.layers')
-    elif '.ff.ff.0.0' in k:
-        k = k.replace('.ff.ff.0.0', '.ff.ff.layers.0.layers.0')
-    elif '.ff.ff.2' in k:
-        k = k.replace('.ff.ff.2', '.ff.ff.layers.2')
-    elif '.time_mlp' in k:
-        k = k.replace('.time_mlp', '.time_mlp.layers')
-    elif '.conv1d' in k:
-        k = k.replace('.conv1d', '.conv1d.layers')
+    elif ".to_out" in k:
+        k = k.replace(".to_out", ".to_out.layers")
+    elif ".text_blocks" in k:
+        k = k.replace(".text_blocks", ".text_blocks.layers")
+    elif ".ff.ff.0.0" in k:
+        k = k.replace(".ff.ff.0.0", ".ff.ff.layers.0.layers.0")
+    elif ".ff.ff.2" in k:
+        k = k.replace(".ff.ff.2", ".ff.ff.layers.2")
+    elif ".time_mlp" in k:
+        k = k.replace(".time_mlp", ".time_mlp.layers")
+    elif ".conv1d" in k:
+        k = k.replace(".conv1d", ".conv1d.layers")

     # reshape weights
-    if '.dwconv.weight' in k:
+    if ".dwconv.weight" in k:
         v = v.swapaxes(1, 2)
-    elif '.conv1d.layers.0.weight' in k:
+    elif ".conv1d.layers.0.weight" in k:
         v = v.swapaxes(1, 2)
-    elif '.conv1d.layers.2.weight' in k:
+    elif ".conv1d.layers.2.weight" in k:
         v = v.swapaxes(1, 2)

     new_state_dict[k] = v

-cfm.load_weights(list(new_state_dict.items()))
-cfm.eval()
+f5tts.load_weights(list(new_state_dict.items()))
+mx.eval(f5tts.parameters())

 target_rms = 0.1
 path = Path("tests/test_en_1_ref_short")
@@ -82,7 +93,15 @@

 start_date = datetime.datetime.now()

-mel, _ = cfm.sample(audio, text=text, duration=duration, steps=32, cfg_strength=1, sway_sampling_coef=None, seed=1234)
+mel, _ = f5tts.sample(
+    audio,
+    text=text,
+    duration=duration,
+    steps=32,
+    cfg_strength=1,
+    sway_sampling_coef=None,
+    seed=1234,
+)

 vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")
 wave = vocos.decode(mel)
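
The weight-loading loop in `test_infer_single.py` renames each PyTorch checkpoint key before loading it into the MLX model, and transposes a few convolution weights. The sketch below isolates that key logic as pure functions so it can be checked without a checkpoint. The names `rename_key`, `needs_axis_swap`, `RENAMES`, `SKIPPED_KEYS`, and `SWAPPED_PATTERNS` are hypothetical, and the example keys are invented for illustration rather than taken from the real checkpoint.

```python
from typing import Optional

# Substring renames, in the same order as the if/elif chain in the script.
RENAMES = [
    (".to_out", ".to_out.layers"),
    (".text_blocks", ".text_blocks.layers"),
    (".ff.ff.0.0", ".ff.ff.layers.0.layers.0"),
    (".ff.ff.2", ".ff.ff.layers.2"),
    (".time_mlp", ".time_mlp.layers"),
    (".conv1d", ".conv1d.layers"),
]
SKIPPED_KEYS = ("initted", "step")
SWAPPED_PATTERNS = (
    ".dwconv.weight",
    ".conv1d.layers.0.weight",
    ".conv1d.layers.2.weight",
)


def rename_key(key: str) -> Optional[str]:
    """Return the renamed key, or None when the script would skip the entry."""
    key = key.replace("ema_model.", "")
    if len(key) < 1 or "mel_spec." in key or key in SKIPPED_KEYS:
        return None
    for old, new in RENAMES:
        if old in key:
            # Mirrors the elif chain: only the first matching rule applies.
            return key.replace(old, new)
    return key


def needs_axis_swap(renamed_key: str) -> bool:
    """True when the script calls swapaxes(1, 2) on this weight."""
    return any(pattern in renamed_key for pattern in SWAPPED_PATTERNS)


# Hypothetical keys, chosen only to exercise each branch.
examples = [
    "ema_model.transformer.blocks.0.attn.to_out.0.weight",
    "ema_model.transformer.text_embed.text_blocks.0.dwconv.weight",
    "ema_model.mel_spec.window",
    "ema_model.step",
]
for example in examples:
    renamed = rename_key(example)
    swap = needs_axis_swap(renamed) if renamed is not None else False
    print(example, "->", renamed, "| swap axes:", swap)
```

The axis swap is checked against the renamed key, as in the script, which is why the `.conv1d.layers.` patterns only match after the `.conv1d` rename has been applied.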