From 4e98a448b0e85c3174b8b4d805d6815027169e3a Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Tue, 15 Oct 2024 08:52:12 -0700 Subject: [PATCH 01/18] Add duration predictor model and training loop. --- f5_tts_mlx/data.py | 223 ++++++++++++++++++++++++++++++++++++++ f5_tts_mlx/duration.py | 238 +++++++++++++++++++++++++++++++++++++++++ f5_tts_mlx/trainer.py | 173 ++++++++++++++++++++++++++++++ 3 files changed, 634 insertions(+) create mode 100644 f5_tts_mlx/data.py create mode 100644 f5_tts_mlx/duration.py create mode 100644 f5_tts_mlx/trainer.py diff --git a/f5_tts_mlx/data.py b/f5_tts_mlx/data.py new file mode 100644 index 0000000..c63712e --- /dev/null +++ b/f5_tts_mlx/data.py @@ -0,0 +1,223 @@ +from functools import partial +import hashlib +from pathlib import Path + +import mlx.core as mx +import mlx.data as dx +import numpy as np + +from einops.array_api import rearrange + +from mlx.data.datasets.common import ( + CACHE_DIR, + ensure_exists, + urlretrieve_with_progress, + file_digest, + gzip_decompress, +) + +from f5_tts_mlx.modules import log_mel_spectrogram + +# utilities + + +def files_with_extensions(dir: Path, extensions: list = ["wav"]): + files = [] + for ext in extensions: + files.extend(list(dir.rglob(f"*.{ext}"))) + files = sorted(files) + + return [{"file": f.as_posix().encode("utf-8")} for f in files] + + +# transforms + + +def _load_transcript_file(sample): + audio_file = Path(bytes(sample["file"]).decode("utf-8")) + transcript_file = audio_file.with_suffix(".normalized.txt") + sample["transcript_file"] = transcript_file.as_posix().encode("utf-8") + return sample + + +def _load_transcript(sample): + audio_file = Path(bytes(sample["file"]).decode("utf-8")) + transcript_file = audio_file.with_suffix(".normalized.txt") + if not transcript_file.exists(): + return dict() + + transcript = np.array( + list(transcript_file.read_text().strip().encode("utf-8")), dtype=np.int8 + ) + sample["transcript"] = transcript + return sample + + +def _load_cached_mel_spec(sample, max_duration=5): + audio_file = Path(bytes(sample["file"]).decode("utf-8")) + mel_file = audio_file.with_suffix(".mel.npy.npz") + mel_spec = mx.load(mel_file.as_posix())["arr_0"] + mel_len = mel_spec.shape[1] + + if mel_len > int(max_duration * 93.75): + return dict() + + sample["mel_spec"] = mel_spec + sample["mel_len"] = mel_len + del sample["file"] + return sample + + +def _load_audio_file(sample): + audio_file = Path(bytes(sample["file"]).decode("utf-8")) + audio = np.array(list(audio_file.read_bytes()), dtype=np.int8) + sample["audio"] = audio + return sample + + +def _to_mel_spec(sample): + audio = rearrange(mx.array(sample["audio"]), "t 1 -> t") + mel_spec = log_mel_spectrogram(audio) + sample["mel_spec"] = mel_spec + sample["mel_len"] = mel_spec.shape[1] + return sample + + +def _with_max_duration(sample, sample_rate=24_000, max_duration=30): + audio_duration = sample["audio"].shape[0] / sample_rate + if audio_duration > max_duration: + return dict() + return sample + + +# dataset loading + +SPLITS = { + "dev-clean": ( + "https://www.openslr.org/resources/141/dev_clean.tar.gz", + "2c1f5312914890634cc2d15783032ff3", + ), + "dev-other": ( + "https://www.openslr.org/resources/141/dev_other.tar.gz", + "62d3a80ad8a282b6f31b3904f0507e4f", + ), + "test-clean": ( + "https://www.openslr.org/resources/141/test_clean.tar.gz", + "4d373d453eb96c0691e598061bbafab7", + ), + "test-other": ( + "https://www.openslr.org/resources/141/test_other.tar.gz", + "dbc0959d8bdb6d52200595cabc9995ae", + ), + "train-clean-100": ( + 
"https://www.openslr.org/resources/141/train_clean_100.tar.gz", + "6df668d8f5f33e70876bfa33862ad02b", + ), + "train-clean-360": ( + "https://www.openslr.org/resources/141/train_clean_360.tar.gz", + "382eb3e64394b3da6a559f864339b22c", + ), + "train-other-500": ( + "https://www.openslr.org/resources/141/train_other_500.tar.gz", + "a37a8e9f4fe79d20601639bf23d1add8", + ), +} + + +def load_libritts_r_tarfile( + root=None, split="dev-clean", quiet=False, validate_download=True +): + """Fetch the libritts_r TAR archive and return the path to it for manual processing. + + Args: + root (Path or str, optional): The The directory to load/save the data. If + none is given the ``~/.cache/mlx.data/libritts_r`` is used. + split (str): The split to use. It should be one of dev-clean, + dev-other, test-clean, test-other, train-clean-100, + train-clean-360, train-other-500 . + quiet (bool): If true do not show download (and possibly decompression) + progress. + """ + if split not in SPLITS: + raise ValueError( + f"Unknown libritts_r split '{split}'. It should be one of [{', '.join(SPLITS.keys())}]" + ) + + if root is None: + root = CACHE_DIR / "libritts_r" + else: + root = Path(root) + ensure_exists(root) + + url, target_hash = SPLITS[split] + filename = Path(url).name + target_compressed = root / filename + target = root / filename.replace(".gz", "") + + if not target.is_file(): + if not target_compressed.is_file(): + urlretrieve_with_progress(url, target_compressed, quiet=quiet) + if validate_download: + h = file_digest(target_compressed, hashlib.md5(), quiet=quiet) + if h.hexdigest() != target_hash: + raise RuntimeError( + f"[libritts_r] File download corrupted md5sums don't match. Please manually delete {str(target_compressed)}." + ) + + gzip_decompress(target_compressed, target, quiet=quiet) + target_compressed.unlink() + + return target + + +def load_libritts_r( + root=None, split="dev-clean", quiet=False, validate_download=True, max_duration=30 +): + """Load the libritts_r dataset directly from the TAR archive. + + Args: + root (Path or str, optional): The The directory to load/save the data. If + none is given the ``~/.cache/mlx.data/libritts_r`` is used. + split (str): The split to use. It should be one of dev-clean, + dev-other, test-clean, test-other, train-clean-100, + train-clean-360, train-other-500 . + quiet (bool): If true do not show download (and possibly decompression) + progress. 
+ """ + + target = load_libritts_r_tarfile( + root=root, split=split, quiet=quiet, validate_download=validate_download + ) + target = str(target) + + dset = ( + dx.files_from_tar(target) + .to_stream() + .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".wav") else dict()) + .sample_transform(_load_transcript_file) + .read_from_tar(target, "transcript_file", "transcript") + .read_from_tar(target, "file", "audio") + .load_audio("audio", from_memory=True) + .sample_transform(partial(_with_max_duration, max_duration=max_duration)) + .sample_transform(_to_mel_spec) + ) + + return dset + + +def load_dir(dir=None, max_duration=30): + path = Path(dir).expanduser() + + files = files_with_extensions(path) + print(f"Found {len(files)} files at {path}") + + dset = ( + dx.buffer_from_vector(files) + .to_stream() + .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".wav") else dict()) + .sample_transform(_load_transcript) + .sample_transform(partial(_load_cached_mel_spec, max_duration=max_duration)) + .pad_to_multiple("mel_spec", dim=1, pad_multiple=512, pad_value=0.0) + ) + + return dset diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py new file mode 100644 index 0000000..0787ad6 --- /dev/null +++ b/f5_tts_mlx/duration.py @@ -0,0 +1,238 @@ +from __future__ import annotations +from pathlib import Path +from random import random +from typing import Callable + +import mlx.core as mx +import mlx.nn as nn + +from einops.array_api import rearrange, reduce, repeat +import einx + +from f5_tts_mlx.cfm import ( + list_str_to_idx, + list_str_to_tensor, + lens_to_mask, + maybe_masked_mean, +) +from f5_tts_mlx.dit import DiT, TextEmbedding, TimestepEmbedding, ConvPositionEmbedding + +from f5_tts_mlx.modules import ( + MelSpec, + RotaryEmbedding, + DiTBlock, + AdaLayerNormZero_Final, +) + +SAMPLE_RATE = 24_000 +HOP_LENGTH = 256 +SAMPLES_PER_SECOND = SAMPLE_RATE / HOP_LENGTH + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +class Rearrange(nn.Module): + def __init__(self, pattern: str): + super().__init__() + self.pattern = pattern + + def __call__(self, x: mx.array) -> mx.array: + return rearrange(x, self.pattern) + + +class DurationInputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def __call__( + self, + x: float["b n d"], + text_embed: float["b n d"], + ): + x = self.proj(mx.concatenate((x, text_embed), axis=-1)) + x = self.conv_pos_embed(x) + x + return x + + +class DurationTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + long_skip_connection=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, conv_layers=conv_layers + ) + self.input_embed = DurationInputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + ) + for _ in range(depth) + ] + self.long_skip_connection = ( + nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + ) + + self.norm_out = nn.RMSNorm(dim) # final 
modulation
+
+    def __call__(
+        self,
+        x: float["b n d"], # noised input audio
+        text: int["b nt"], # text
+        mask: bool["b n"] | None = None,
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+
+        time = mx.ones((batch,), dtype=mx.float32)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+
+        text_embed = self.text_embed(text, seq_len)
+
+        x = self.input_embed(x, text_embed)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if self.long_skip_connection is not None:
+            residual = x
+
+        for block in self.transformer_blocks:
+            x = block(x, t, mask=mask, rope=rope)
+
+        if self.long_skip_connection is not None:
+            x = self.long_skip_connection(mx.concatenate((x, residual), axis=-1))
+
+        x = self.norm_out(x)
+
+        return x
+
+
+class DurationPredictor(nn.Module):
+    def __init__(
+        self,
+        transformer: DiT,
+        num_channels=None,
+        mel_spec_kwargs: dict = dict(),
+        vocab_char_map: dict[str, int] | None = None,
+    ):
+        super().__init__()
+
+        # mel spec
+
+        self.mel_spec = MelSpec(**mel_spec_kwargs)
+        num_channels = default(num_channels, self.mel_spec.n_mels)
+        self.num_channels = num_channels
+
+        self.transformer = transformer
+        dim = transformer.dim
+        self.dim = dim
+
+        self.dim = dim
+
+        # self.proj_in = nn.Linear(self.num_channels, self.dim)
+
+        # vocab map for tokenization
+        self.vocab_char_map = vocab_char_map
+
+        # to prediction
+
+        self.to_pred = nn.Sequential(
+            nn.Linear(dim, 1, bias=False), nn.Softplus(), Rearrange("... 1 -> ...")
+        )
+
+    def __call__(
+        self,
+        inp: mx.array["b n d"] | mx.array["b nw"], # mel or raw wave
+        text: mx.array | list[str],
+        *,
+        lens: mx.array["b"] | None = None,
+        return_loss=False,
+    ):
+        # handle raw wave
+        if inp.ndim == 2:
+            inp = self.mel_spec(inp)
+            inp = rearrange(inp, "b d n -> b n d")
+            assert inp.shape[-1] == self.num_channels
+
+        batch, seq_len = inp.shape[:2]
+
+        # handle text as string
+        if isinstance(text, list):
+            if exists(self.vocab_char_map):
+                text = list_str_to_idx(text, self.vocab_char_map)
+            else:
+                text = list_str_to_tensor(text)
+            assert text.shape[0] == batch
+
+        # lens and mask
+        if not exists(lens):
+            lens = mx.full((batch,), seq_len)
+
+        mask = lens_to_mask(lens, length=seq_len)
+
+        # if returning a loss, mask out randomly from an index and have it predict the duration
+
+        if return_loss:
+            rand_frac_index = mx.random.uniform(0, 1, (batch,))
+            rand_index = (rand_frac_index * lens).astype(mx.int32)
+
+            seq = mx.arange(seq_len)
+            mask &= einx.less("n, b -> b n", seq, rand_index)
+
+        # attending
+
+        inp = mx.where(
+            repeat(mask, "b n -> b n d", d=self.num_channels), inp, mx.zeros_like(inp)
+        )
+
+        x = self.transformer(inp, text=text)
+
+        x = maybe_masked_mean(x, mask)
+
+        pred = self.to_pred(x)
+
+        # return the prediction if not returning loss
+
+        if not return_loss:
+            return pred
+
+        # loss
+
+        duration = lens.astype(mx.float32) / SAMPLES_PER_SECOND
+
+        return nn.losses.mse_loss(pred, duration)
diff --git a/f5_tts_mlx/trainer.py b/f5_tts_mlx/trainer.py
new file mode 100644
index 0000000..8208bfe
--- /dev/null
+++ b/f5_tts_mlx/trainer.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+import datetime
+
+from einops.array_api import rearrange
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.optimizers import (
+    AdamW,
+    linear_schedule,
+    cosine_decay,
+    join_schedules,
+    clip_grad_norm,
+)
+from mlx.utils import tree_flatten
+
+from f5_tts_mlx.cfm import F5TTS
+from f5_tts_mlx.duration import DurationPredictor
+from f5_tts_mlx.modules import MelSpec
+
+import wandb
+
+
+def exists(v):
+    return v is not None
+
+
+def default(v, d):
+    return v if exists(v) else d
+
+
+# trainer
+
+
+class DurationTrainer:
+    def __init__(
+        self,
+        model: DurationPredictor,
+        num_warmup_steps=1000,
+        max_grad_norm=1.0,
+        sample_rate=24_000,
+        log_with_wandb=False,
+    ):
+        self.model = model
+        self.num_warmup_steps = num_warmup_steps
+        self.mel_spectrogram = MelSpec(sample_rate=sample_rate)
+        self.max_grad_norm = max_grad_norm
+        self.log_with_wandb = log_with_wandb
+
+    def save_checkpoint(self, step, finetune=False):
+        mx.save_safetensors(
+            f"f5tts_duration_{step}",
+            dict(tree_flatten(self.model.trainable_parameters())),
+        )
+
+    def load_checkpoint(self, step):
+        params = mx.load(f"f5tts_duration_{step}.safetensors")
+        self.model.load_weights(list(params.items()))
+        self.model.eval()
+
+    def train(
+        self,
+        train_dataset,
+        learning_rate=1e-4,
+        weight_decay=1e-2,
+        total_steps=100_000,
+        batch_size=8,
+        log_every=10,
+        save_every=1000,
+        checkpoint: int | None = None,
+    ):
+        if self.log_with_wandb:
+            wandb.init(
+                project="f5tts_duration",
+                config=dict(
+                    learning_rate=learning_rate,
+                    total_steps=total_steps,
+                    batch_size=batch_size,
+                ),
+            )
+
+        decay_steps = total_steps - self.num_warmup_steps
+
+        warmup_scheduler = linear_schedule(
+            init=1e-8,
+            end=learning_rate,
+            steps=self.num_warmup_steps,
+        )
+        decay_scheduler = cosine_decay(init=learning_rate, decay_steps=decay_steps)
+        scheduler = join_schedules(
+            schedules=[warmup_scheduler, decay_scheduler],
+            boundaries=[self.num_warmup_steps],
+        )
+        self.optimizer = AdamW(learning_rate=scheduler, weight_decay=weight_decay)
+
+        if checkpoint is not None:
+            self.load_checkpoint(checkpoint)
+            start_step = checkpoint
+        else:
+            start_step = 0
+
+        global_step = start_step
+
+        def loss_fn(model: DurationPredictor, mel_spec, text, lens):
+            loss = model(mel_spec, text=text, lens=lens, return_loss=True)
+            return loss
+
+        # state = [self.model.state, self.optimizer.state, mx.random.state]
+
+        # @partial(mx.compile, inputs=state, outputs=state)
+        def train_step(mel_spec, text_inputs, mel_lens):
+            loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
+            loss, grads = loss_and_grad_fn(
+                self.model, mel_spec, text=text_inputs, lens=mel_lens
+            )
+
+            if self.max_grad_norm > 0:
+                grads, _ = clip_grad_norm(grads, max_norm=self.max_grad_norm)
+
+            self.optimizer.update(self.model, grads)
+
+            return loss
+
+        training_start_date = datetime.datetime.now()
+        log_start_date = datetime.datetime.now()
+
+        batched_dataset = (
+            train_dataset.repeat(1_000_000)  # repeat indefinitely
+            .shuffle(1000)
+            .prefetch(prefetch_size=batch_size, num_threads=4)
+            .batch(batch_size)
+        )
+
+        for batch in batched_dataset:
+            effective_batch_size = batch["transcript"].shape[0]
+            text_inputs = [
+                bytes(batch["transcript"][i]).decode("utf-8")
+                for i in range(effective_batch_size)
+            ]
+
+            mel_spec = rearrange(mx.array(batch["mel_spec"]), "b 1 n c -> b n c")
+            mel_lens = mx.array(batch["mel_len"], dtype=mx.int32)
+
+            loss = train_step(mel_spec, text_inputs, mel_lens)
+            # mx.eval(state)
+            mx.eval(self.model.parameters(), self.optimizer.state)
+
+            if self.log_with_wandb:
+                wandb.log(
+                    {"loss": loss.item(), "lr": self.optimizer.learning_rate.item()},
+                    step=global_step,
+                )
+
+            if global_step > 0 and global_step % log_every == 0:
+                elapsed_time = datetime.datetime.now() - log_start_date
+                log_start_date = datetime.datetime.now()
+
+                print(
+                    f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(elapsed_time.seconds / log_every):.2f}"
+                )
+
+            global_step += 1
+
+            if global_step % save_every == 0:
+                self.save_checkpoint(global_step)
+
+            if global_step >= total_steps:
+                break
+
+        if self.log_with_wandb:
+            wandb.finish()
+
+        print(f"Training complete in {datetime.datetime.now() - training_start_date}")

From 76b2bad2f4bed06fa76e92f0d41d21f8b662804a Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Tue, 15 Oct 2024 12:09:43 -0700
Subject: [PATCH 02/18] Update README.md

---
 README.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/README.md b/README.md
index f26edab..52ca308 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,15 @@ python -m f5_tts_mlx.generate \
 --text "The quick brown fox jumped over the lazy dog."
 ```
 
+If you want to use your own reference audio sample, make sure it's encoded at 24kHz and use the --ref-audio and --ref-text options:
+
+```bash
+python -m f5_tts_mlx.generate \
+--text "The quick brown fox jumped over the lazy dog." \
+--ref-audio /path/to/audio.wav \
+--ref-text "This is the caption for the reference audio."
+```
+
 See [examples/generate.py](./examples) for more options.
 
 —

From 063de6cb8d5ea68536bf0108712387b3ee4bd146 Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Tue, 15 Oct 2024 12:12:51 -0700
Subject: [PATCH 03/18] Update README.md

---
 examples/README.md | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/examples/README.md b/examples/README.md
index e5eb565..90de076 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -15,13 +15,21 @@ string
 
 Provide the text that you want to generate.
 
+## Optional Parameters
+
 `-–duration`
 
 float
 
 Specify the length of the generated audio in seconds.
 
-## Optional Parameters
+
+`-–speed`
+
+float, default: 1.0
+
+Speaking speed modifier, used when an exact duration is not specified.
+
 
 `--model`
 
 string, default: "lucasnewman/f5-tts-mlx"
 
 Specify a custom model to use for generation. If not provided, the script will use the default model.
 
+
 `--ref-audio`
 
 string, default: "tests/test_en_1_ref_short.wav"
 
 Provide a reference audio file path to help guide the generation.
 
+
 `–-ref-text`
 
 string, default: "Some call me nature, others call me mother nature."

From c0ba0e1e9946bda6602c87f329f5e0538d9b98c0 Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Tue, 15 Oct 2024 12:13:30 -0700
Subject: [PATCH 04/18] Update README.md

---
 examples/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/README.md b/examples/README.md
index 90de076..c7bdab1 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -63,7 +63,7 @@ Specify the output path where the generated audio will be saved. If not specifie
 
 float, default: 0.0
 
-Set the sway sampling coefficient. The best values according to the paper are in the range of [-0.4...0.4].
+Set the sway sampling coefficient. The best values according to the paper are in the range of [-1.0...1.0].
 
 
 `-–seed`

From 7d5832ef589f9b3511f61647f5ae20e73f4c2fa5 Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Tue, 15 Oct 2024 12:33:22 -0700
Subject: [PATCH 05/18] No need for time embedding in the duration model.
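
The duration model makes a single deterministic regression pass over the
reference mel/text pair rather than an iterative flow-matching solve, so
the timestep embedding only ever saw the constant t = 1 and had nothing
meaningful to condition on. This drops TimestepEmbedding and replaces
DiTBlock with a plain pre-norm block. A sketch of the interface change
(details in the diff below):

    # before: a constant "time" was embedded and threaded into every block
    t = self.time_embed(mx.ones((batch,), dtype=mx.float32))
    x = block(x, t, mask=mask, rope=rope)

    # after: plain transformer blocks, no conditioning signal
    x = block(x, mask=mask, rope=rope)

Also folds in a fix to list_str_to_idx: each index list becomes its own
mx.array before padding (ragged lists can't be stacked directly), and the
padding_value argument is now forwarded instead of a hard-coded -1.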
--- f5_tts_mlx/cfm.py | 6 ++-- f5_tts_mlx/duration.py | 70 +++++++++++++++++++++++++----------------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index 4e8a4b0..bd32c56 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -138,14 +138,16 @@ def list_str_to_tensor(text: list[str], padding_value=-1) -> mx.array: # Int['b def list_str_to_idx( - text: list[str] | list[list[str]], + text: list[str], vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ) -> mx.array: # Int['b nt']: list_idx_tensors = [ [vocab_char_map.get(c, 0) for c in t] for t in text ] # pinyin or char style - text = pad_sequence(mx.array(list_idx_tensors), padding_value=-1) + + list_idx_tensors = [mx.array(t) for t in list_idx_tensors] + text = pad_sequence(list_idx_tensors, padding_value=padding_value) return text diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index 0787ad6..f63ebf9 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -1,12 +1,9 @@ from __future__ import annotations -from pathlib import Path -from random import random -from typing import Callable import mlx.core as mx import mlx.nn as nn -from einops.array_api import rearrange, reduce, repeat +from einops.array_api import rearrange, repeat import einx from f5_tts_mlx.cfm import ( @@ -15,13 +12,13 @@ lens_to_mask, maybe_masked_mean, ) -from f5_tts_mlx.dit import DiT, TextEmbedding, TimestepEmbedding, ConvPositionEmbedding +from f5_tts_mlx.dit import DiT, TextEmbedding, ConvPositionEmbedding from f5_tts_mlx.modules import ( + Attention, + FeedForward, MelSpec, - RotaryEmbedding, - DiTBlock, - AdaLayerNormZero_Final, + RotaryEmbedding ) SAMPLE_RATE = 24_000 @@ -60,7 +57,40 @@ def __call__( x = self.proj(mx.concatenate((x, text_embed), axis=-1)) x = self.conv_pos_embed(x) + x return x + +class DurationTransformerBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = nn.RMSNorm(dim) + self.attn = Attention( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, affine=False, eps=1e-6) + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def __call__( + self, x, mask=None, rope=None + ): # x: noised input + norm = self.attn_norm(x) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + # process attention output for input x + x = x * attn_output + + norm = self.ff_norm(x) + ff_output = self.ff(norm) + x = x * ff_output + + return x class DurationTransformer(nn.Module): def __init__( @@ -75,12 +105,10 @@ def __init__( mel_dim=100, text_num_embeds=256, text_dim=None, - conv_layers=0, - long_skip_connection=False, + conv_layers=0 ): super().__init__() - self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( @@ -94,7 +122,7 @@ def __init__( self.depth = depth self.transformer_blocks = [ - DiTBlock( + DurationTransformerBlock( dim=dim, heads=heads, dim_head=dim_head, @@ -103,9 +131,6 @@ def __init__( ) for _ in range(depth) ] - self.long_skip_connection = ( - nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - ) self.norm_out = nn.RMSNorm(dim) # final modulation @@ -115,12 +140,7 @@ def __call__( text: int["b nt"], # text mask: bool["b n"] | None = None, ): - batch, seq_len = x.shape[0], x.shape[1] - - time = mx.ones((batch,), dtype=mx.float32) - - # t: conditioning time, c: context (text + masked cond 
audio), x: noised input audio - t = self.time_embed(time) + seq_len = x.shape[1] text_embed = self.text_embed(text, seq_len) @@ -128,14 +148,8 @@ def __call__( rope = self.rotary_embed.forward_from_seq_len(seq_len) - if self.long_skip_connection is not None: - residual = x - for block in self.transformer_blocks: - x = block(x, t, mask=mask, rope=rope) - - if self.long_skip_connection is not None: - x = self.long_skip_connection(mx.concatenate((x, residual), axis=-1)) + x = block(x, mask=mask, rope=rope) x = self.norm_out(x) From 26f70239990a5ac780a5781bf62c7a8a6afbcdf6 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Tue, 15 Oct 2024 13:32:33 -0700 Subject: [PATCH 06/18] Keep the enhanced conditioning on the duration predictor, it seems to help. --- f5_tts_mlx/data.py | 2 +- f5_tts_mlx/duration.py | 53 ++++++++---------------------------------- 2 files changed, 11 insertions(+), 44 deletions(-) diff --git a/f5_tts_mlx/data.py b/f5_tts_mlx/data.py index c63712e..fc8aabe 100644 --- a/f5_tts_mlx/data.py +++ b/f5_tts_mlx/data.py @@ -217,7 +217,7 @@ def load_dir(dir=None, max_duration=30): .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".wav") else dict()) .sample_transform(_load_transcript) .sample_transform(partial(_load_cached_mel_spec, max_duration=max_duration)) - .pad_to_multiple("mel_spec", dim=1, pad_multiple=512, pad_value=0.0) + .pad_to_multiple("mel_spec", dim=1, pad_multiple=1024, pad_value=0.0) ) return dset diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index f63ebf9..ed2b695 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -12,13 +12,12 @@ lens_to_mask, maybe_masked_mean, ) -from f5_tts_mlx.dit import DiT, TextEmbedding, ConvPositionEmbedding +from f5_tts_mlx.dit import DiT, TextEmbedding, TimestepEmbedding, ConvPositionEmbedding from f5_tts_mlx.modules import ( - Attention, - FeedForward, MelSpec, - RotaryEmbedding + RotaryEmbedding, + DiTBlock, ) SAMPLE_RATE = 24_000 @@ -57,40 +56,7 @@ def __call__( x = self.proj(mx.concatenate((x, text_embed), axis=-1)) x = self.conv_pos_embed(x) + x return x - -class DurationTransformerBlock(nn.Module): - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): - super().__init__() - - self.attn_norm = nn.RMSNorm(dim) - self.attn = Attention( - dim=dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - ) - - self.ff_norm = nn.LayerNorm(dim, affine=False, eps=1e-6) - self.ff = FeedForward( - dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" - ) - - def __call__( - self, x, mask=None, rope=None - ): # x: noised input - norm = self.attn_norm(x) - # attention - attn_output = self.attn(x=norm, mask=mask, rope=rope) - - # process attention output for input x - x = x * attn_output - - norm = self.ff_norm(x) - ff_output = self.ff(norm) - x = x * ff_output - - return x class DurationTransformer(nn.Module): def __init__( @@ -105,10 +71,11 @@ def __init__( mel_dim=100, text_num_embeds=256, text_dim=None, - conv_layers=0 + conv_layers=0, ): super().__init__() + self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim self.text_embed = TextEmbedding( @@ -122,7 +89,7 @@ def __init__( self.depth = depth self.transformer_blocks = [ - DurationTransformerBlock( + DiTBlock( dim=dim, heads=heads, dim_head=dim_head, @@ -140,7 +107,9 @@ def __call__( text: int["b nt"], # text mask: bool["b n"] | None = None, ): - seq_len = x.shape[1] + batch, seq_len = x.shape[0], x.shape[1] + + t = self.time_embed(mx.ones((batch,), dtype=mx.float32)) text_embed = self.text_embed(text, 
seq_len) @@ -149,7 +118,7 @@ def __call__( rope = self.rotary_embed.forward_from_seq_len(seq_len) for block in self.transformer_blocks: - x = block(x, mask=mask, rope=rope) + x = block(x, t, mask=mask, rope=rope) x = self.norm_out(x) @@ -178,8 +147,6 @@ def __init__( self.dim = dim - # self.proj_in = nn.Linear(self.num_channels, self.dim) - # vocab map for tokenization self.vocab_char_map = vocab_char_map From fe153131664bd42bd3c265181e31390f7e249e67 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 16 Oct 2024 08:46:08 -0700 Subject: [PATCH 07/18] Update README.md --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 52ca308..3ffbf4a 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ python -m f5_tts_mlx.generate \ --text "The quick brown fox jumped over the lazy dog." ``` -If you want to use your own reference audio sample, make sure it's encoded at 24kHz and use the --ref-audio and --ref-text options: +If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds: ```bash python -m f5_tts_mlx.generate \ @@ -32,6 +32,12 @@ python -m f5_tts_mlx.generate \ --ref-text "This is the caption for the reference audio." ``` +You can convert an audio file to the correct format with ffmpeg like this: + +```bash +ffmpeg -i /path/to/audio.wav -ac 1 -ar 24000 -sample_fmt s16 -t 10 /path/to/output_audio.wav +``` + See [examples/generate.py](./examples) for more options. — From 80a19bc7fe7a2b631cf1981cbcdb1e74df962823 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 16 Oct 2024 08:46:58 -0700 Subject: [PATCH 08/18] Update README.md --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 3ffbf4a..5f87d61 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,7 @@ pip install f5-tts-mlx ## Usage ```bash -python -m f5_tts_mlx.generate \ ---text "The quick brown fox jumped over the lazy dog." +python -m f5_tts_mlx.generate --text "The quick brown fox jumped over the lazy dog." ``` If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds: From b2138b945e242dd3370ce1f6dddb050b7ada563e Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 16 Oct 2024 10:05:48 -0700 Subject: [PATCH 09/18] Add a sampler for the duration predictor. 
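
The predictor regresses duration in seconds (its training target is
lens / SAMPLES_PER_SECOND), so sampling is a single forward pass over
the reference audio plus target text, with the result clamped to at most
max_duration frames worth of audio. A rough usage sketch (the variable
names and dim value here are illustrative, not part of this diff):

    predictor = DurationPredictor(
        transformer=DurationTransformer(dim=512),
        vocab_char_map=vocab,
    )
    # cond: mono 24kHz reference audio, text: raw strings or token ids
    duration_in_sec = predictor.sample(cond, text=["some target text"])
    n_frames = int(duration_in_sec.item() * SAMPLES_PER_SECOND)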
--- f5_tts_mlx/duration.py | 49 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index ed2b695..8b5f286 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -108,11 +108,11 @@ def __call__( mask: bool["b n"] | None = None, ): batch, seq_len = x.shape[0], x.shape[1] - + t = self.time_embed(mx.ones((batch,), dtype=mx.float32)) - + text_embed = self.text_embed(text, seq_len) - + x = self.input_embed(x, text_embed) rope = self.rotary_embed.forward_from_seq_len(seq_len) @@ -128,7 +128,7 @@ def __call__( class DurationPredictor(nn.Module): def __init__( self, - transformer: DiT, + transformer: DurationTransformer, num_channels=None, mel_spec_kwargs: dict = dict(), vocab_char_map: dict[str, int] | None = None, @@ -200,11 +200,11 @@ def __call__( inp = mx.where( repeat(mask, "b n -> b n d", d=self.num_channels), inp, mx.zeros_like(inp) ) - + x = self.transformer(inp, text=text) x = maybe_masked_mean(x, mask) - + pred = self.to_pred(x) # return the prediction if not returning loss @@ -217,3 +217,40 @@ def __call__( duration = lens.astype(mx.float32) / SAMPLES_PER_SECOND return nn.losses.mse_loss(pred, duration) + + def sample( + self, + cond: mx.array["b n d"] | mx.array["b nw"], + text: mx.array["b nt"] | list[str], + *, + lens: mx.array["b"] | None = None, + max_duration=4096, + ) -> tuple[mx.array, mx.array]: + self.eval() + + # raw wave + + if cond.ndim == 2: + cond = rearrange(cond, "1 n -> n") + cond = self.mel_spec(cond) + assert cond.shape[-1] == self.num_channels + + batch, cond_seq_len, dtype = *cond.shape[:2], cond.dtype + if not exists(lens): + lens = mx.full((batch,), cond_seq_len, dtype=dtype) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map) + else: + text = list_str_to_tensor(text) + assert text.shape[0] == batch + + if exists(text): + text_lens = (text != -1).sum(axis=-1) + lens = mx.maximum(text_lens, lens) + + pred = self.transformer(cond, text=text) + return mx.minimum(max_duration / SAMPLES_PER_SECOND, pred) From 3465f85bc59ece360db6cdc0c670972b618e3702 Mon Sep 17 00:00:00 2001 From: felix_red_panda Date: Fri, 18 Oct 2024 01:49:31 +0200 Subject: [PATCH 10/18] add requirements.txt --- requirements.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7722007 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +einops +einx +jieba +huggingface_hub +mlx +numpy +pypinyin +setuptools +soundfile +vocos-mlx \ No newline at end of file From b2b070c86b6af44990badcd01ec58f581dff8478 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 21 Oct 2024 08:43:57 -0700 Subject: [PATCH 11/18] Add ODE steps as a parameter. --- examples/README.md | 27 ++++++++++++++++++++------- examples/generate.py | 10 +++++++++- f5_tts_mlx/cfm.py | 2 +- f5_tts_mlx/generate.py | 10 +++++++++- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/examples/README.md b/examples/README.md index c7bdab1..19284a6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -17,14 +17,14 @@ Provide the text that you want to generate. ## Optional Parameters -`-–duration` +`--duration` float Specify the length of the generated audio in seconds. 
-`-–speed` +`--speed` float, default: 1.0 @@ -45,28 +45,41 @@ string, default: "tests/test_en_1_ref_short.wav" Provide a reference audio file path to help guide the generation. -`–-ref-text` +`--ref-text` string, default: "Some call me nature, others call me mother nature." Provide a caption for the reference audio. -`-–output` +`--output` string, default: "output.wav" Specify the output path where the generated audio will be saved. If not specified, the script will save the output to a default location. +`--cfg` -`-–sway-coef` +float, default: 2.0 -float, default: 0.0 +Specifies the strength used for classifier free guidance + + +`--steps` + +int, default: 32 + +Specify the number of steps used to sample the neural ODE. Lower steps trade off quality for latency. + + +`--sway-coef` + +float, default: -1.0 Set the sway sampling coefficient. The best values according to the paper are in the range of [-1.0...1.0]. -`-–seed` +`--seed` int, default: None (random) diff --git a/examples/generate.py b/examples/generate.py index 1272c4a..c98dd61 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -27,6 +27,7 @@ def generate( model_name: str = "lucasnewman/f5-tts-mlx", ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, + steps: int = 32, cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, speed: float = 1.0, # used when duration is None as part of the duration heuristic @@ -83,7 +84,7 @@ def generate( mx.expand_dims(audio, axis=0), text=text, duration=frame_duration, - steps=32, + steps=steps, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, @@ -138,6 +139,12 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of steps to take when sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -171,6 +178,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + steps=args.steps, cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, speed=args.speed, diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index bd32c56..95ee68c 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -479,4 +479,4 @@ def from_pretrained( f5tts.load_weights(list(weights.items())) mx.eval(f5tts.parameters()) - return f5tts \ No newline at end of file + return f5tts diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py index 1272c4a..c98dd61 100644 --- a/f5_tts_mlx/generate.py +++ b/f5_tts_mlx/generate.py @@ -27,6 +27,7 @@ def generate( model_name: str = "lucasnewman/f5-tts-mlx", ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, + steps: int = 32, cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, speed: float = 1.0, # used when duration is None as part of the duration heuristic @@ -83,7 +84,7 @@ def generate( mx.expand_dims(audio, axis=0), text=text, duration=frame_duration, - steps=32, + steps=steps, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, @@ -138,6 +139,12 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of steps to take when sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -171,6 +178,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + steps=args.steps, cfg_strength=args.cfg, 
         sway_sampling_coef=args.sway_coef,
         speed=args.speed,

From d13ca5ae06a2df5aff1d4ae587dd04360791b608 Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Mon, 21 Oct 2024 08:48:21 -0700
Subject: [PATCH 12/18] 0.1.2.

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index a2df269..0aeebbd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts-mlx"
-version = "0.1.0"
+version = "0.1.2"
 authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}]
 license = {text = "MIT"}
 description = "F5-TTS - MLX"

From dfb96c90c92d3a210082ea0353955bc19cc12ff8 Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Mon, 21 Oct 2024 09:50:03 -0700
Subject: [PATCH 13/18] Add a trainer for the main model.

---
 f5_tts_mlx/cfm.py     |  15 +----
 f5_tts_mlx/trainer.py | 139 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 142 insertions(+), 12 deletions(-)

diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py
index 95ee68c..81f3a29 100644
--- a/f5_tts_mlx/cfm.py
+++ b/f5_tts_mlx/cfm.py
@@ -159,11 +159,6 @@ def __init__(
         self,
         transformer: nn.Module,
         sigma=0.0,
-        odeint_kwargs: dict = dict(
-            # atol = 1e-5,
-            # rtol = 1e-5,
-            method="euler"  # 'midpoint'
-        ),
         audio_drop_prob=0.3,
         cond_drop_prob=0.2,
         num_channels=None,
@@ -193,9 +188,6 @@ def __init__(
         # conditional flow related
         self.sigma = sigma
 
-        # sampling related
-        self.odeint_kwargs = odeint_kwargs
-
         # vocab map for tokenization
         self.vocab_char_map = vocab_char_map
 
@@ -230,7 +222,7 @@ def __call__(
 
         # get a random span to mask out for training conditionally
         frac_lengths = mx.random.uniform(*self.frac_lengths_mask, (batch,))
-        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
+        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len)
 
         if exists(mask):
             rand_span_mask = rand_span_mask & mask
@@ -282,7 +274,7 @@ def __call__(
         masked_loss = mx.where(rand_span_mask, loss, mx.zeros_like(loss))
         loss = mx.sum(masked_loss) / mx.maximum(mx.sum(rand_span_mask), 1e-6)
 
-        return loss.mean(), cond
+        return loss.mean()
 
     def odeint(self, func, y0, t, **kwargs):
         """
@@ -326,7 +318,6 @@ def sample(
         max_duration=4096,
         vocoder: Callable[[mx.array["b d n"]], mx.array["b nw"]] | None = None,
         no_ref_audio=False,
-        t_inter=0.1,
         edit_mask=None,
     ) -> tuple[mx.array, mx.array]:
         self.eval()
@@ -435,7 +426,7 @@ def fn(t, x):
         if exists(sway_sampling_coef):
             t = t + sway_sampling_coef * (mx.cos(mx.pi / 2 * t) - 1 + t)
 
-        trajectory = self.odeint(fn, y0, t, **self.odeint_kwargs)
+        trajectory = self.odeint(fn, y0, t)
         sampled = trajectory[-1]
         out = sampled

diff --git a/f5_tts_mlx/trainer.py b/f5_tts_mlx/trainer.py
index 8208bfe..ea0c282 100644
--- a/f5_tts_mlx/trainer.py
+++ b/f5_tts_mlx/trainer.py
@@ -171,3 +171,142 @@
             wandb.finish()
 
         print(f"Training complete in {datetime.datetime.now() - training_start_date}")
+
+class F5TTSTrainer:
+    def __init__(
+        self,
+        model: F5TTS,
+        num_warmup_steps=1000,
+        max_grad_norm=1.0,
+        sample_rate=24_000,
+        log_with_wandb=False,
+    ):
+        self.model = model
+        self.num_warmup_steps = num_warmup_steps
+        self.mel_spectrogram = MelSpec(sample_rate=sample_rate)
+        self.max_grad_norm = max_grad_norm
+        self.log_with_wandb = log_with_wandb
+
+    def save_checkpoint(self, step, finetune=False):
+        mx.save_safetensors(
+            f"f5tts_{step}",
+            dict(tree_flatten(self.model.trainable_parameters())),
+        )
+
+    def load_checkpoint(self, step):
+        params = mx.load(f"f5tts_{step}.safetensors")
+        self.model.load_weights(list(params.items()))
+        self.model.eval()
+
+    def train(
+        self,
+        train_dataset,
+        learning_rate=1e-4,
+        weight_decay=1e-2,
+        total_steps=100_000,
+        batch_size=8,
+        log_every=10,
+        save_every=1000,
+        checkpoint: int | None = None,
+    ):
+        if self.log_with_wandb:
+            wandb.init(
+                project="f5tts",
+                config=dict(
+                    learning_rate=learning_rate,
+                    total_steps=total_steps,
+                    batch_size=batch_size,
+                ),
+            )
+
+        decay_steps = total_steps - self.num_warmup_steps
+
+        warmup_scheduler = linear_schedule(
+            init=1e-8,
+            end=learning_rate,
+            steps=self.num_warmup_steps,
+        )
+        decay_scheduler = cosine_decay(init=learning_rate, decay_steps=decay_steps)
+        scheduler = join_schedules(
+            schedules=[warmup_scheduler, decay_scheduler],
+            boundaries=[self.num_warmup_steps],
+        )
+        self.optimizer = AdamW(learning_rate=scheduler, weight_decay=weight_decay)
+
+        if checkpoint is not None:
+            self.load_checkpoint(checkpoint)
+            start_step = checkpoint
+        else:
+            start_step = 0
+
+        global_step = start_step
+
+        def loss_fn(model: F5TTS, mel_spec, text, lens):
+            return model(mel_spec, text=text, lens=lens)
+
+        # state = [self.model.state, self.optimizer.state, mx.random.state]
+
+        # @partial(mx.compile, inputs=state, outputs=state)
+        def train_step(mel_spec, text_inputs, mel_lens):
+            loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
+            loss, grads = loss_and_grad_fn(
+                self.model, mel_spec, text=text_inputs, lens=mel_lens
+            )
+
+            if self.max_grad_norm > 0:
+                grads, _ = clip_grad_norm(grads, max_norm=self.max_grad_norm)
+
+            self.optimizer.update(self.model, grads)
+
+            return loss
+
+        training_start_date = datetime.datetime.now()
+        log_start_date = datetime.datetime.now()
+
+        batched_dataset = (
+            train_dataset.repeat(1_000_000)  # repeat indefinitely
+            .shuffle(1000)
+            .prefetch(prefetch_size=batch_size, num_threads=4)
+            .batch(batch_size)
+        )
+
+        for batch in batched_dataset:
+            effective_batch_size = batch["transcript"].shape[0]
+            text_inputs = [
+                bytes(batch["transcript"][i]).decode("utf-8")
+                for i in range(effective_batch_size)
+            ]
+
+            mel_spec = rearrange(mx.array(batch["mel_spec"]), "b 1 n c -> b n c")
+            mel_lens = mx.array(batch["mel_len"], dtype=mx.int32)
+
+            loss = train_step(mel_spec, text_inputs, mel_lens)
+            # mx.eval(state)
+            mx.eval(self.model.parameters(), self.optimizer.state)
+
+            if self.log_with_wandb:
+                wandb.log(
+                    {"loss": loss.item(), "lr": self.optimizer.learning_rate.item()},
+                    step=global_step,
+                )
+
+            if global_step > 0 and global_step % log_every == 0:
+                elapsed_time = datetime.datetime.now() - log_start_date
+                log_start_date = datetime.datetime.now()
+
+                print(
+                    f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(elapsed_time.seconds / log_every):.2f}"
+                )
+
+            global_step += 1
+
+            if global_step % save_every == 0:
+                self.save_checkpoint(global_step)
+
+            if global_step >= total_steps:
+                break
+
+        if self.log_with_wandb:
+            wandb.finish()
+
+        print(f"Training complete in {datetime.datetime.now() - training_start_date}")

From 4dd52da1916384fc386cf5f39387f922a1c8f77d Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Mon, 21 Oct 2024 11:54:31 -0700
Subject: [PATCH 14/18] Switch to the duration predictor model.
--- f5_tts_mlx/cfm.py | 189 ++++++++++++----------------------------- f5_tts_mlx/duration.py | 70 ++++----------- f5_tts_mlx/generate.py | 20 +---- f5_tts_mlx/utils.py | 150 ++++++++++++++++++++++++++++++++ 4 files changed, 226 insertions(+), 203 deletions(-) diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index 81f3a29..1db7287 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -15,141 +15,21 @@ import mlx.core as mx import mlx.nn as nn -from einops.array_api import rearrange, reduce, repeat -import einx +from einops.array_api import rearrange, repeat +from f5_tts_mlx.duration import DurationPredictor, DurationTransformer from f5_tts_mlx.dit import DiT from f5_tts_mlx.modules import MelSpec - -from huggingface_hub import snapshot_download - - -def fetch_from_hub(hf_repo: str) -> Path: - model_path = Path( - snapshot_download( - repo_id=hf_repo, - allow_patterns=["*.safetensors", "*.txt"], - ) - ) - return model_path - - -def exists(v): - return v is not None - - -def default(v, d): - return v if exists(v) else d - - -def divisible_by(num, den): - return (num % den) == 0 - - -def lens_to_mask( - t: mx.array, - length: int | None = None, -) -> mx.array: # Bool['b n'] - if not exists(length): - length = t.max() - - seq = mx.arange(length) - return einx.less("n, b -> b n", seq, t) - - -def mask_from_start_end_indices( - seq_len: mx.array, - start: mx.array, - end: mx.array, - max_length: int | None = None, -): - max_seq_len = default(max_length, seq_len.max().item()) - seq = mx.arange(max_seq_len).astype(mx.int32) - return einx.greater_equal("n, b -> b n", seq, start) & einx.less( - "n, b -> b n", seq, end - ) - - -def mask_from_frac_lengths( - seq_len: mx.array, - frac_lengths: mx.array, - max_length: int | None = None, -): - lengths = (frac_lengths * seq_len).astype(mx.int32) - max_start = seq_len - lengths - - rand = mx.random.uniform(0, 1, frac_lengths.shape) - - start = mx.maximum((max_start * rand).astype(mx.int32), 0) - end = start + lengths - - out = mask_from_start_end_indices(seq_len, start, end, max_length) - - if exists(max_length): - out = pad_to_length(out, max_length) - - return out - - -def maybe_masked_mean(t: mx.array, mask: mx.array | None = None) -> mx.array: - if not exists(mask): - return t.mean(dim=1) - - t = einx.where("b n, b n d, -> b n d", mask, t, 0.0) - num = reduce(t, "b n d -> b d", "sum") - den = reduce(mask.astype(mx.int32), "b n -> b", "sum") - - return einx.divide("b d, b -> b d", num, mx.maximum(den, 1)) - - -def pad_to_length(t: mx.array, length: int, value=None): - ndim = t.ndim - seq_len = t.shape[-1] - if length > seq_len: - if ndim == 1: - t = mx.pad(t, [(0, length - seq_len)], constant_values=value) - elif ndim == 2: - t = mx.pad(t, [(0, 0), (0, length - seq_len)], constant_values=value) - elif ndim == 3: - t = mx.pad( - t, [(0, 0), (0, length - seq_len), (0, 0)], constant_values=value - ) - else: - raise ValueError(f"Unsupported padding dims: {ndim}") - return t[..., :length] - - -def pad_sequence(t: mx.array, padding_value=0): - max_len = max([i.shape[-1] for i in t]) - t = mx.array([pad_to_length(i, max_len, padding_value) for i in t]) - return t - - -# simple utf-8 tokenizer, since paper went character based - - -def list_str_to_tensor(text: list[str], padding_value=-1) -> mx.array: # Int['b nt']: - list_tensors = [mx.array([*bytes(t, "UTF-8")]) for t in text] - padded_tensor = pad_sequence(list_tensors, padding_value=-1) - return padded_tensor - - -# char tokenizer, based on custom dataset's extracted .txt file - - -def list_str_to_idx( 
- text: list[str], - vocab_char_map: dict[str, int], # {char: idx} - padding_value=-1, -) -> mx.array: # Int['b nt']: - list_idx_tensors = [ - [vocab_char_map.get(c, 0) for c in t] for t in text - ] # pinyin or char style - - list_idx_tensors = [mx.array(t) for t in list_idx_tensors] - text = pad_sequence(list_idx_tensors, padding_value=padding_value) - return text - +from f5_tts_mlx.utils import ( + exists, + default, + lens_to_mask, + mask_from_frac_lengths, + list_str_to_idx, + list_str_to_tensor, + pad_sequence, + fetch_from_hub, +) # conditional flow matching @@ -166,6 +46,7 @@ def __init__( mel_spec_kwargs: dict = dict(), frac_lengths_mask: tuple[float, float] = (0.7, 1.0), vocab_char_map: dict[str, int] | None = None, + duration_predictor: DurationPredictor | None = None, ): super().__init__() @@ -191,6 +72,9 @@ def __init__( # vocab map for tokenization self.vocab_char_map = vocab_char_map + # duration predictor (optional) + self._duration_predictor = duration_predictor + def __call__( self, inp: mx.array["b n d"] | mx.array["b nw"], # mel or raw wave @@ -308,11 +192,12 @@ def sample( self, cond: mx.array["b n d"] | mx.array["b nw"], text: mx.array["b nt"] | list[str], - duration: int | mx.array["b"], + duration: int | mx.array["b"] | None = None, *, lens: mx.array["b"] | None = None, steps=32, cfg_strength=1.0, + speed=1.0, sway_sampling_coef=None, seed: int | None = None, max_duration=4096, @@ -349,6 +234,17 @@ def sample( # duration + if duration is None and self._duration_predictor is not None: + duration_in_sec = self._duration_predictor(cond, text) + frame_rate = self.mel_spec.sample_rate // self.mel_spec.hop_length + duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item() + print(f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech.") + + # include the reference audio length + duration = duration + lens + elif duration is None: + raise ValueError("Duration must be provided or a duration predictor must be set.") + cond_mask = lens_to_mask(lens) if edit_mask is not None: cond_mask = cond_mask & edit_mask @@ -449,9 +345,31 @@ def from_pretrained( if path is None: raise ValueError(f"Could not find model {hf_model_name_or_path}") - model_path = path / "model.safetensors" vocab_path = path / "vocab.txt" vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))} + if len(vocab) == 0: + raise ValueError(f"Could not load vocab from {vocab_path}") + + duration_model_path = path / "duration_model.safetensors" + duration_predictor = None + + if duration_model_path.exists(): + duration_predictor = DurationPredictor( + transformer=DurationTransformer( + dim=256, + depth=8, + heads=8, + text_dim=256, + ff_mult=2, + conv_layers=4, + text_num_embeds=len(vocab) - 1, + ), + vocab_char_map=vocab, + ) + weights = mx.load(duration_model_path.as_posix(), format="safetensors") + duration_predictor.load_weights(list(weights.items())) + + model_path = path / "model.safetensors" f5tts = F5TTS( transformer=DiT( @@ -464,8 +382,9 @@ def from_pretrained( text_num_embeds=len(vocab) - 1, ), vocab_char_map=vocab, + duration_predictor=duration_predictor, ) - + weights = mx.load(model_path.as_posix(), format="safetensors") f5tts.load_weights(list(weights.items())) mx.eval(f5tts.parameters()) diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index 8b5f286..ae159f0 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -1,3 +1,12 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave 
length +d - dimension +""" + from __future__ import annotations import mlx.core as mx @@ -6,33 +15,27 @@ from einops.array_api import rearrange, repeat import einx -from f5_tts_mlx.cfm import ( +from f5_tts_mlx.dit import TextEmbedding, TimestepEmbedding, ConvPositionEmbedding +from f5_tts_mlx.modules import ( + MelSpec, + RotaryEmbedding, + DiTBlock, +) +from f5_tts_mlx.utils import ( + exists, + default, list_str_to_idx, list_str_to_tensor, lens_to_mask, maybe_masked_mean, ) -from f5_tts_mlx.dit import DiT, TextEmbedding, TimestepEmbedding, ConvPositionEmbedding -from f5_tts_mlx.modules import ( - MelSpec, - RotaryEmbedding, - DiTBlock, -) SAMPLE_RATE = 24_000 HOP_LENGTH = 256 SAMPLES_PER_SECOND = SAMPLE_RATE / HOP_LENGTH -def exists(v): - return v is not None - - -def default(v, d): - return v if exists(v) else d - - class Rearrange(nn.Module): def __init__(self, pattern: str): super().__init__() @@ -217,40 +220,3 @@ def __call__( duration = lens.astype(mx.float32) / SAMPLES_PER_SECOND return nn.losses.mse_loss(pred, duration) - - def sample( - self, - cond: mx.array["b n d"] | mx.array["b nw"], - text: mx.array["b nt"] | list[str], - *, - lens: mx.array["b"] | None = None, - max_duration=4096, - ) -> tuple[mx.array, mx.array]: - self.eval() - - # raw wave - - if cond.ndim == 2: - cond = rearrange(cond, "1 n -> n") - cond = self.mel_spec(cond) - assert cond.shape[-1] == self.num_channels - - batch, cond_seq_len, dtype = *cond.shape[:2], cond.dtype - if not exists(lens): - lens = mx.full((batch,), cond_seq_len, dtype=dtype) - - # text - - if isinstance(text, list): - if exists(self.vocab_char_map): - text = list_str_to_idx(text, self.vocab_char_map) - else: - text = list_str_to_tensor(text) - assert text.shape[0] == batch - - if exists(text): - text_lens = (text != -1).sum(axis=-1) - lens = mx.maximum(text_lens, lens) - - pred = self.transformer(cond, text=text) - return mx.minimum(max_duration / SAMPLES_PER_SECOND, pred) diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py index c98dd61..cbf4a72 100644 --- a/f5_tts_mlx/generate.py +++ b/f5_tts_mlx/generate.py @@ -30,7 +30,7 @@ def generate( steps: int = 32, cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, - speed: float = 1.0, # used when duration is None as part of the duration heuristic + speed: float = 1.0, # used when duration is None as part of the duration heuristic seed: Optional[int] = None, output_path: str = "output.wav", ): @@ -63,19 +63,6 @@ def generate( # generate the audio for the given text text = convert_char_to_pinyin([ref_audio_text + " " + generation_text]) - - # use a heuristic to determine the duration if not provided - if duration is None: - ref_audio_len = audio.shape[0] // HOP_LENGTH - zh_pause_punc = r"。,、;:?!" 
- ref_text_len = len(ref_audio_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_audio_text)) - gen_text_len = len(generation_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, generation_text)) - duration_in_frames = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) - duration = (duration_in_frames / FRAMES_PER_SEC) - ref_audio_duration - print(f"Using duration of {duration:.2f} seconds for generated speech.") - - frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC) - print(f"Generating {frame_duration} total frames of audio...") start_date = datetime.datetime.now() vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") @@ -83,8 +70,9 @@ def generate( wave, _ = f5tts.sample( mx.expand_dims(audio, axis=0), text=text, - duration=frame_duration, + duration=None, steps=steps, + speed=speed, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, @@ -92,7 +80,7 @@ def generate( ) # trim the reference audio - wave = wave[audio.shape[0]:] + wave = wave[audio.shape[0] :] generated_duration = wave.shape[0] / SAMPLE_RATE elapsed_time = datetime.datetime.now() - start_date diff --git a/f5_tts_mlx/utils.py b/f5_tts_mlx/utils.py index 631c20c..263e7f0 100644 --- a/f5_tts_mlx/utils.py +++ b/f5_tts_mlx/utils.py @@ -1,6 +1,143 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +from pathlib import Path + +import mlx.core as mx + +from einops.array_api import reduce +import einx + +from huggingface_hub import snapshot_download + import jieba from pypinyin import lazy_pinyin, Style + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +def divisible_by(num, den): + return (num % den) == 0 + + +def lens_to_mask( + t: mx.array, + length: int | None = None, +) -> mx.array: # Bool['b n'] + if not exists(length): + length = t.max() + + seq = mx.arange(length) + return einx.less("n, b -> b n", seq, t) + + +def mask_from_start_end_indices( + seq_len: mx.array, + start: mx.array, + end: mx.array, + max_length: int | None = None, +): + max_seq_len = default(max_length, seq_len.max().item()) + seq = mx.arange(max_seq_len).astype(mx.int32) + return einx.greater_equal("n, b -> b n", seq, start) & einx.less( + "n, b -> b n", seq, end + ) + + +def mask_from_frac_lengths( + seq_len: mx.array, + frac_lengths: mx.array, + max_length: int | None = None, +): + lengths = (frac_lengths * seq_len).astype(mx.int32) + max_start = seq_len - lengths + + rand = mx.random.uniform(0, 1, frac_lengths.shape) + + start = mx.maximum((max_start * rand).astype(mx.int32), 0) + end = start + lengths + + out = mask_from_start_end_indices(seq_len, start, end, max_length) + + if exists(max_length): + out = pad_to_length(out, max_length) + + return out + + +def maybe_masked_mean(t: mx.array, mask: mx.array | None = None) -> mx.array: + if not exists(mask): + return t.mean(dim=1) + + t = einx.where("b n, b n d, -> b n d", mask, t, 0.0) + num = reduce(t, "b n d -> b d", "sum") + den = reduce(mask.astype(mx.int32), "b n -> b", "sum") + + return einx.divide("b d, b -> b d", num, mx.maximum(den, 1)) + + +def pad_to_length(t: mx.array, length: int, value=None): + ndim = t.ndim + seq_len = t.shape[-1] + if length > seq_len: + if ndim == 1: + t = mx.pad(t, [(0, length - seq_len)], constant_values=value) + elif ndim == 2: + t = mx.pad(t, [(0, 0), (0, length - seq_len)], constant_values=value) + elif ndim == 3: + t = mx.pad( + t, 
[(0, 0), (0, length - seq_len), (0, 0)], constant_values=value + ) + else: + raise ValueError(f"Unsupported padding dims: {ndim}") + return t[..., :length] + + +def pad_sequence(t: mx.array, padding_value=0): + max_len = max([i.shape[-1] for i in t]) + t = mx.array([pad_to_length(i, max_len, padding_value) for i in t]) + return t + + +# simple utf-8 tokenizer, since paper went character based + + +def list_str_to_tensor(text: list[str], padding_value=-1) -> mx.array: # Int['b nt']: + list_tensors = [mx.array([*bytes(t, "UTF-8")]) for t in text] + padded_tensor = pad_sequence(list_tensors, padding_value=-1) + return padded_tensor + + +# char tokenizer, based on custom dataset's extracted .txt file + + +def list_str_to_idx( + text: list[str], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +) -> mx.array: # Int['b nt']: + list_idx_tensors = [ + [vocab_char_map.get(c, 0) for c in t] for t in text + ] # pinyin or char style + + list_idx_tensors = [mx.array(t) for t in list_idx_tensors] + text = pad_sequence(list_idx_tensors, padding_value=padding_value) + return text + + # convert char to pinyin @@ -43,3 +180,16 @@ def convert_char_to_pinyin(text_list, polyphone=True): final_text_list.append(char_list) return final_text_list + + +# fetch model from hub + + +def fetch_from_hub(hf_repo: str) -> Path: + model_path = Path( + snapshot_download( + repo_id=hf_repo, + allow_patterns=["*.safetensors", "*.txt"], + ) + ) + return model_path From 7e5540d88cf586c20da65fa22dd39570f086078d Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 21 Oct 2024 15:02:11 -0700 Subject: [PATCH 15/18] Fix duration from predictor. --- f5_tts_mlx/cfm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index 1db7287..0589606 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -239,9 +239,6 @@ def sample( frame_rate = self.mel_spec.sample_rate // self.mel_spec.hop_length duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item() print(f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech.") - - # include the reference audio length - duration = duration + lens elif duration is None: raise ValueError("Duration must be provided or a duration predictor must be set.") From badbe8c270bf530fb04e3cb5a290a2c6f538e133 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 21 Oct 2024 15:02:41 -0700 Subject: [PATCH 16/18] 0.1.3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0aeebbd..f98346e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts-mlx" -version = "0.1.2" +version = "0.1.3" authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}] license = {text = "MIT"} description = "F5-TTS - MLX" From 56dfd08282c8a245038bed00909d89f967d85b5d Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Tue, 22 Oct 2024 09:24:51 -0700 Subject: [PATCH 17/18] Add euler ODE sampling for faster generation, and include the vocoder in the pretrained model loading stage. 
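
Euler needs one transformer evaluation per step where midpoint needs two,
so at a fixed step count it roughly halves generation time at some cost
in integration accuracy:

    # euler:    y_{n+1} = y_n + dt * f(t_n, y_n)
    # midpoint: y_{n+1} = y_n + dt * f(t_n + dt/2, y_n + (dt/2) * f(t_n, y_n))

The solver is selectable from the CLI via the new --method flag added in
this diff, e.g.:

    python -m f5_tts_mlx.generate --text "Hello." --method euler --steps 32

Loading the vocoder inside F5TTS.from_pretrained also means callers no
longer have to construct a Vocos instance themselves to decode the
sampled mel spectrogram.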
--- examples/generate.py | 42 +++++++--------- f5_tts_mlx/cfm.py | 112 ++++++++++++++++++++++++++++++----------- f5_tts_mlx/dit.py | 2 +- f5_tts_mlx/duration.py | 2 +- f5_tts_mlx/generate.py | 28 +++++++---- 5 files changed, 121 insertions(+), 65 deletions(-) diff --git a/examples/generate.py b/examples/generate.py index c98dd61..4422984 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,8 +1,7 @@ import argparse import datetime import pkgutil -import re -from typing import Optional +from typing import Literal, Optional import mlx.core as mx @@ -11,8 +10,6 @@ from f5_tts_mlx.cfm import F5TTS from f5_tts_mlx.utils import convert_char_to_pinyin -from vocos_mlx import Vocos - import soundfile as sf SAMPLE_RATE = 24_000 @@ -28,9 +25,10 @@ def generate( ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, steps: int = 32, + method: Literal["euler", "midpoint"] = "euler", cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, - speed: float = 1.0, # used when duration is None as part of the duration heuristic + speed: float = 0.8, # used when duration is None as part of the duration heuristic seed: Optional[int] = None, output_path: str = "output.wav", ): @@ -63,36 +61,26 @@ def generate( # generate the audio for the given text text = convert_char_to_pinyin([ref_audio_text + " " + generation_text]) - - # use a heuristic to determine the duration if not provided - if duration is None: - ref_audio_len = audio.shape[0] // HOP_LENGTH - zh_pause_punc = r"。,、;:?!" - ref_text_len = len(ref_audio_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_audio_text)) - gen_text_len = len(generation_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, generation_text)) - duration_in_frames = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) - duration = (duration_in_frames / FRAMES_PER_SEC) - ref_audio_duration - print(f"Using duration of {duration:.2f} seconds for generated speech.") - - frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC) - print(f"Generating {frame_duration} total frames of audio...") start_date = datetime.datetime.now() - vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") + + if duration is not None: + duration = int(duration * FRAMES_PER_SEC) wave, _ = f5tts.sample( mx.expand_dims(audio, axis=0), text=text, - duration=frame_duration, + duration=duration, steps=steps, + method=method, + speed=speed, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, - vocoder=vocos.decode, ) # trim the reference audio - wave = wave[audio.shape[0]:] + wave = wave[audio.shape[0] :] generated_duration = wave.shape[0] / SAMPLE_RATE elapsed_time = datetime.datetime.now() - start_date @@ -145,6 +133,13 @@ def generate( default=32, help="Number of steps to take when sampling the neural ODE", ) + parser.add_argument( + "--method", + type=str, + default="euler", + choices=["euler", "midpoint"], + help="Method to use for sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -160,7 +155,7 @@ def generate( parser.add_argument( "--speed", type=float, - default=1.0, + default=0.8, help="Speed factor for the duration heuristic", ) parser.add_argument( @@ -179,6 +174,7 @@ def generate( ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, steps=args.steps, + method=args.method, cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, speed=args.speed, diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index 0589606..e7d7293 100644 --- a/f5_tts_mlx/cfm.py +++ 
b/f5_tts_mlx/cfm.py @@ -8,15 +8,18 @@ """ from __future__ import annotations +from datetime import datetime from pathlib import Path from random import random -from typing import Callable +from typing import Callable, Literal import mlx.core as mx import mlx.nn as nn from einops.array_api import rearrange, repeat +from vocos_mlx import Vocos + from f5_tts_mlx.duration import DurationPredictor, DurationTransformer from f5_tts_mlx.dit import DiT from f5_tts_mlx.modules import MelSpec @@ -46,6 +49,7 @@ def __init__( mel_spec_kwargs: dict = dict(), frac_lengths_mask: tuple[float, float] = (0.7, 1.0), vocab_char_map: dict[str, int] | None = None, + vocoder: Callable[[mx.array["b d n"]], mx.array["b nw"]] | None = None, duration_predictor: DurationPredictor | None = None, ): super().__init__() @@ -72,6 +76,9 @@ def __init__( # vocab map for tokenization self.vocab_char_map = vocab_char_map + # vocoder (optional) + self._vocoder = vocoder + # duration predictor (optional) self._duration_predictor = duration_predictor @@ -106,7 +113,7 @@ def __call__( # get a random span to mask out for training conditionally frac_lengths = mx.random.uniform(*self.frac_lengths_mask, (batch,)) - rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length = seq_len) + rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, max_length=seq_len) if exists(mask): rand_span_mask = rand_span_mask & mask @@ -160,7 +167,7 @@ def __call__( return loss.mean() - def odeint(self, func, y0, t, **kwargs): + def odeint_midpoint(self, func, y0, t): """ Solves ODE using the midpoint method. @@ -188,6 +195,30 @@ def odeint(self, func, y0, t, **kwargs): return mx.stack(ys) + def odeint_euler(self, func, y0, t): + """ + Solves ODE using the Euler method. + + Parameters: + - y0: Initial state, an MLX array of any shape. + - t: Array of time steps, an MLX array. + """ + ys = [y0] + y_current = y0 + + for i in range(len(t) - 1): + t_current = t[i] + dt = t[i + 1] - t_current + + # compute the next value + k = func(t_current, y_current) + y_next = y_current + dt * k + + ys.append(y_next) + y_current = y_next + + return mx.stack(ys) + def sample( self, cond: mx.array["b n d"] | mx.array["b nw"], @@ -196,15 +227,17 @@ def sample( *, lens: mx.array["b"] | None = None, steps=32, - cfg_strength=1.0, + method: Literal["euler", "midpoint"] = "euler", + cfg_strength=2.0, speed=1.0, - sway_sampling_coef=None, + sway_sampling_coef=-1.0, seed: int | None = None, max_duration=4096, - vocoder: Callable[[mx.array["b d n"]], mx.array["b nw"]] | None = None, no_ref_audio=False, edit_mask=None, ) -> tuple[mx.array, mx.array]: + start_date = datetime.now() + self.eval() # raw wave @@ -238,9 +271,13 @@ def sample( duration_in_sec = self._duration_predictor(cond, text) frame_rate = self.mel_spec.sample_rate // self.mel_spec.hop_length duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item() - print(f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech.") + print( + f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech." + ) elif duration is None: - raise ValueError("Duration must be provided or a duration predictor must be set.") + raise ValueError( + "Duration must be provided or a duration predictor must be set." + ) cond_mask = lens_to_mask(lens) if edit_mask is not None: @@ -260,9 +297,10 @@ def sample( constant_values=False, ) cond_mask = rearrange(cond_mask, "... -> ... 
1") - step_cond = mx.where( - cond_mask, cond, mx.zeros_like(cond) - ) # allow direct control (cut cond audio) with lens passed in + + # at each step, conditioning is fixed + + step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond)) if batch > 1: mask = lens_to_mask(duration) @@ -276,9 +314,6 @@ def sample( # neural ode def fn(t, x): - # at each step, conditioning is fixed - # step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond)) - # predict flow pred = self.transformer( x=x, @@ -290,6 +325,7 @@ def fn(t, x): drop_text=False, ) if cfg_strength < 1e-5: + mx.eval(pred) return pred null_pred = self.transformer( @@ -301,11 +337,12 @@ def fn(t, x): drop_audio_cond=True, drop_text=True, ) - return pred + (pred - null_pred) * cfg_strength + output = pred + (pred - null_pred) * cfg_strength + mx.eval(output) + return output # noise input - # to make sure batch inference result is same with different batch size, and for sure single inference - # still some difference maybe due to convolutional layers + y0 = [] for dur in duration: if exists(seed): @@ -319,37 +356,45 @@ def fn(t, x): if exists(sway_sampling_coef): t = t + sway_sampling_coef * (mx.cos(mx.pi / 2 * t) - 1 + t) - trajectory = self.odeint(fn, y0, t) + if method == "midpoint": + trajectory = self.odeint_midpoint(fn, y0, t) + elif method == "euler": + trajectory = self.odeint_euler(fn, y0, t) + else: + raise ValueError(f"Unknown method: {method}") sampled = trajectory[-1] out = sampled out = mx.where(cond_mask, cond, out) - if exists(vocoder): - out = vocoder(out) - + if exists(self._vocoder): + out = self._vocoder(out) + mx.eval(out) + print(f"Generated speech in {datetime.now() - start_date}") + return out, trajectory @classmethod - def from_pretrained( - cls, - hf_model_name_or_path: str - ) -> F5TTS: + def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS: path = fetch_from_hub(hf_model_name_or_path) - + if path is None: raise ValueError(f"Could not find model {hf_model_name_or_path}") + # vocab + vocab_path = path / "vocab.txt" vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))} if len(vocab) == 0: raise ValueError(f"Could not load vocab from {vocab_path}") + # duration predictor + duration_model_path = path / "duration_model.safetensors" duration_predictor = None - + if duration_model_path.exists(): duration_predictor = DurationPredictor( transformer=DurationTransformer( @@ -365,9 +410,15 @@ def from_pretrained( ) weights = mx.load(duration_model_path.as_posix(), format="safetensors") duration_predictor.load_weights(list(weights.items())) - + + # vocoder + + vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") + + # model + model_path = path / "model.safetensors" - + f5tts = F5TTS( transformer=DiT( dim=1024, @@ -379,11 +430,12 @@ def from_pretrained( text_num_embeds=len(vocab) - 1, ), vocab_char_map=vocab, + vocoder=vocos.decode, duration_predictor=duration_predictor, ) weights = mx.load(model_path.as_posix(), format="safetensors") f5tts.load_weights(list(weights.items())) mx.eval(f5tts.parameters()) - + return f5tts diff --git a/f5_tts_mlx/dit.py b/f5_tts_mlx/dit.py index aa57b06..a30a1bc 100644 --- a/f5_tts_mlx/dit.py +++ b/f5_tts_mlx/dit.py @@ -114,7 +114,7 @@ def __init__( depth=8, heads=8, dim_head=64, - dropout=0.1, + dropout=0.0, ff_mult=4, mel_dim=100, text_num_embeds=256, diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index ae159f0..eafddd7 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -69,7 +69,7 @@ def __init__( depth=8, heads=8, 
dim_head=64, - dropout=0.1, + dropout=0.0, ff_mult=4, mel_dim=100, text_num_embeds=256, diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py index cbf4a72..1541f6b 100644 --- a/f5_tts_mlx/generate.py +++ b/f5_tts_mlx/generate.py @@ -1,8 +1,7 @@ import argparse import datetime import pkgutil -import re -from typing import Optional +from typing import Literal, Optional import mlx.core as mx @@ -11,8 +10,6 @@ from f5_tts_mlx.cfm import F5TTS from f5_tts_mlx.utils import convert_char_to_pinyin -from vocos_mlx import Vocos - import soundfile as sf SAMPLE_RATE = 24_000 @@ -28,9 +25,10 @@ def generate( ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, steps: int = 32, + method: Literal["euler", "midpoint"] = "euler", cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, - speed: float = 1.0, # used when duration is None as part of the duration heuristic + speed: float = 0.8, # used when duration is None as part of the duration heuristic seed: Optional[int] = None, output_path: str = "output.wav", ): @@ -65,18 +63,20 @@ def generate( text = convert_char_to_pinyin([ref_audio_text + " " + generation_text]) start_date = datetime.datetime.now() - vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") + + if duration is not None: + duration = int(duration * FRAMES_PER_SEC) wave, _ = f5tts.sample( mx.expand_dims(audio, axis=0), text=text, - duration=None, + duration=duration, steps=steps, + method=method, speed=speed, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, - seed=seed, - vocoder=vocos.decode, + seed=seed ) # trim the reference audio @@ -133,6 +133,13 @@ def generate( default=32, help="Number of steps to take when sampling the neural ODE", ) + parser.add_argument( + "--method", + type=str, + default="euler", + choices=["euler", "midpoint"], + help="Method to use for sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -148,7 +155,7 @@ def generate( parser.add_argument( "--speed", type=float, - default=1.0, + default=0.8, help="Speed factor for the duration heuristic", ) parser.add_argument( @@ -167,6 +174,7 @@ def generate( ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, steps=args.steps, + method=args.method, cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, speed=args.speed, From 2d66562010c5657b31aee654473bd327f4bc2405 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Tue, 22 Oct 2024 09:25:20 -0700 Subject: [PATCH 18/18] 0.1.4 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f98346e..69f2a05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts-mlx" -version = "0.1.3" +version = "0.1.4" authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}] license = {text = "MIT"} description = "F5-TTS - MLX"
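After this series, generation no longer requires constructing a vocoder or
computing a frame count by hand. A minimal programmatic sketch of the new
calling convention follows; the repo id below is a placeholder (any Hugging
Face repo laid out with model.safetensors, vocab.txt, and, for duration=None,
duration_model.safetensors), and the zeroed reference audio stands in for a
real 24 kHz clip with a matching transcript:

    import mlx.core as mx

    from f5_tts_mlx.cfm import F5TTS
    from f5_tts_mlx.utils import convert_char_to_pinyin

    # placeholder repo id -- substitute a repo with model.safetensors + vocab.txt
    f5tts = F5TTS.from_pretrained("<user>/<f5-tts-repo>")

    audio = mx.zeros((24_000,))  # stand-in for one second of 24 kHz reference audio
    ref_text = "reference transcript"
    gen_text = "text to generate"
    text = convert_char_to_pinyin([ref_text + " " + gen_text])

    wave, _ = f5tts.sample(
        mx.expand_dims(audio, axis=0),
        text=text,
        duration=None,   # defer to the bundled duration predictor, if present
        steps=32,
        method="euler",  # new in patch 17; "midpoint" is the previous behavior
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
    )
    wave = wave[audio.shape[0]:]  # trim the reference audio, as generate.py does

Note that duration=None falls back to the bundled duration predictor and
raises a ValueError if the loaded repo does not ship one, per sample() in
cfm.py.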