Skip to content

Commit

Permalink
Updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 13, 2024
1 parent 86e44dc commit 30ca712
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 127 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
release:
types: [published]

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
17 changes: 2 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,12 @@ Pretrained model weights are available [on Hugging Face](https://huggingface.co/
import mlx.core as mx

from f5-tts-mlx.cfm import CFM
from f5-tts-mlx.dit import DiT

vocab = ...
f5tts = CFM(
transformer = DiT(
dim = 1024,
depth = 22,
heads = 16,
ff_mult = 2,
text_dim = 512,
conv_layers = 4,
text_num_embeds = ...
),
vocab_char_map=vocab
)
mx.eval(f5tts.parameters())
f5tts = CFM.from_pretrained("lucasnewman/f5-tts-mlx", vocab)
```

See `test_infer_single.py` for an example of generation.
See `examples/generate.py` for an example of generation.

## Appreciation

Expand Down
140 changes: 140 additions & 0 deletions examples/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import argparse
import datetime

from pathlib import Path

import mlx.core as mx

import numpy as np

from f5_tts_mlx.cfm import CFM
from f5_tts_mlx.utils import convert_char_to_pinyin

from vocos_mlx import Vocos

import torch
import torchaudio

SAMPLE_RATE = 24_000
HOP_LENGTH = 256
FRAMES_PER_SEC = SAMPLE_RATE / HOP_LENGTH
TARGET_RMS = 0.1


def generate(
generation_text: str,
duration: float,
model_name: str = "lucasnewman/f5-tts-mlx",
vocab_path: str = "data/Emilia_ZH_EN_pinyin/vocab.txt",
ref_audio_path: str = "tests/test_en_1_ref_short.wav",
ref_audio_text: str = "Some call me nature, others call me mother nature.",
sway_sampling_coef: float = 0.0,
output_path: str = "output.wav",
):
vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}

f5tts = CFM.from_pretrained(model_name, vocab)

# load reference audio
audio, sr = torchaudio.load(Path(ref_audio_path))
audio = mx.array(audio.numpy())
ref_audio_duration = audio.shape[1] / SAMPLE_RATE

rms = mx.sqrt(mx.mean(mx.square(audio)))
if rms < TARGET_RMS:
audio = audio * TARGET_RMS / rms

# generate the audio for the given text
text = convert_char_to_pinyin([ref_audio_text + " " + generation_text])

frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC)
print(f"Generating {frame_duration} total frames of audio...")

start_date = datetime.datetime.now()
vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

wave, _ = f5tts.sample(
audio,
text=text,
duration=frame_duration,
steps=32,
cfg_strength=1,
sway_sampling_coef=sway_sampling_coef,
seed=1234,
vocoder=vocos.decode,
)

# trim the reference audio
wave = wave[audio.shape[1] :]
generated_duration = len(wave) / SAMPLE_RATE
elapsed_time = datetime.datetime.now() - start_date

print(f"Generated {generated_duration:.2f} seconds of audio in {elapsed_time}.")

torchaudio.save(output_path, torch.Tensor(np.array(wave)).unsqueeze(0), SAMPLE_RATE)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate audio from text using f5-tts-mlx"
)

parser.add_argument(
"--model",
type=str,
default="lucasnewman/f5-tts-mlx",
help="Name of the model to use",
)
parser.add_argument(
"--text", type=str, required=True, help="Text to generate speech from"
)
parser.add_argument(
"--duration",
type=float,
required=True,
help="Duration of the generated audio in seconds",
)
parser.add_argument(
"--vocab",
type=str,
default="data/Emilia_ZH_EN_pinyin/vocab.txt",
help="Path to the vocab file",
)
parser.add_argument(
"--ref-audio",
type=str,
default="tests/test_en_1_ref_short.wav",
help="Path to the reference audio file",
)
parser.add_argument(
"--ref-text",
type=str,
default="Some call me nature, others call me mother nature.",
help="Text spoken in the reference audio",
)
parser.add_argument(
"--output",
type=str,
default="output.wav",
help="Path to save the generated audio output",
)

parser.add_argument(
"--sway-coef",
type=float,
default="0.0",
help="Coefficient for sway sampling",
)

args = parser.parse_args()

generate(
generation_text=args.text,
duration=args.duration,
model_name=args.model,
vocab_path=args.vocab,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
sway_sampling_coef=args.sway_coef,
output_path=args.output,
)
83 changes: 83 additions & 0 deletions f5_tts_mlx/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from __future__ import annotations
from pathlib import Path
from random import random
from typing import Callable

Expand All @@ -17,8 +18,21 @@
from einops.array_api import rearrange, reduce, repeat
import einx

from f5_tts_mlx.dit import DiT
from f5_tts_mlx.modules import MelSpec

from huggingface_hub import snapshot_download


def fetch_from_hub(hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.safetensors"],
)
)
return model_path / "model.safetensors"


def exists(v):
return v is not None
Expand Down Expand Up @@ -431,3 +445,72 @@ def fn(t, x):
mx.eval(out)

return out, trajectory

@classmethod
def from_pretrained(
cls,
hf_model_name_or_path: str,
vocab_char_map: dict[str, int],
) -> CFM:
path = fetch_from_hub(hf_model_name_or_path)
print(path)

if path is None:
raise ValueError(f"Could not find model {hf_model_name_or_path}")

f5tts = CFM(
transformer=DiT(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
text_num_embeds=len(vocab_char_map) - 1,
),
vocab_char_map=vocab_char_map,
)

weights = mx.load(path.as_posix(), format="safetensors")
f5tts.load_weights(list(weights.items()))

# state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)["ema_model_state_dict"]

# # load weights

# new_state_dict = {}
# for k, v in state_dict.items():
# k = k.replace("ema_model.", "")
# v = mx.array(v.numpy())

# # rename layers
# if len(k) < 1 or "mel_spec." in k or k in ("initted", "step"):
# continue
# elif ".to_out" in k:
# k = k.replace(".to_out", ".to_out.layers")
# elif ".text_blocks" in k:
# k = k.replace(".text_blocks", ".text_blocks.layers")
# elif ".ff.ff.0.0" in k:
# k = k.replace(".ff.ff.0.0", ".ff.ff.layers.0.layers.0")
# elif ".ff.ff.2" in k:
# k = k.replace(".ff.ff.2", ".ff.ff.layers.2")
# elif ".time_mlp" in k:
# k = k.replace(".time_mlp", ".time_mlp.layers")
# elif ".conv1d" in k:
# k = k.replace(".conv1d", ".conv1d.layers")

# # reshape weights
# if ".dwconv.weight" in k:
# v = v.swapaxes(1, 2)
# elif ".conv1d.layers.0.weight" in k:
# v = v.swapaxes(1, 2)
# elif ".conv1d.layers.2.weight" in k:
# v = v.swapaxes(1, 2)

# new_state_dict[k] = v

# f5tts.load_weights(list(new_state_dict.items()))

mx.eval(f5tts.parameters())

return f5tts
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
[build-system]
requires = [
"einops",
"einx",
"jieba",
"huggingface_hub",
"mlx",
"numpy",
"setuptools",
"torch",
"torchaudio",
"vocos-mlx"
]
build-backend = "setuptools.build_meta"

[project]
name = "f5-tts-mlx"
version = "0.0.1"
version = "0.0.2"
authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}]
license = {text = "MIT"}
description = "F5-TTS - MLX"
Expand Down Expand Up @@ -40,3 +43,6 @@ Homepage = "https://github.com/lucasnewman/f5-tts-mlx"

[tool.setuptools]
packages = ["f5_tts_mlx"]

[tool.setuptools.package-data]
f5_tts_mlx = ["data/**/*.txt", "assets/mel_filters.npz"]
Loading

0 comments on commit 30ca712

Please sign in to comment.