add MultiBandDiffusion #109

Merged: 1 commit, Aug 4, 2023
README.md (8 changes: 6 additions & 2 deletions)
@@ -11,9 +11,9 @@ Google Colab demo: [![Open In Colab](https://colab.research.google.com/assets/co

## Videos

| **Refining Bark TTS vocals using Demucs & Vocos** | **Demo - How to use RVC with Tortoise** | **How To Get More Voices for Bark TTS** |
| **TTS Generation WebUI - A Tool for Text to Speech and Voice Cloning** | **Text to speech and voice cloning - TTS Generation WebUI** | **AudioGen Unveils New Text-to-Audio Capabilities** |
| :------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------: |
| [![Watch the video](https://img.youtube.com/vi/jCb-8JE7pk8/sddefault.jpg)](https://youtu.be/jCb-8JE7pk8) | [![Watch the video](https://img.youtube.com/vi/mhp_e8WSpxA/sddefault.jpg)](https://youtu.be/mhp_e8WSpxA) | [![Watch the video](https://img.youtube.com/vi/yeC5vJoavOE/sddefault.jpg)](https://youtu.be/yeC5vJoavOE) |
| [![Watch the video](https://img.youtube.com/vi/JXojhFjZ39k/sddefault.jpg)](https://youtu.be/JXojhFjZ39k) | [![Watch the video](https://img.youtube.com/vi/ScN2ypewABc/sddefault.jpg)](https://youtu.be/ScN2ypewABc) | [![Watch the video](https://img.youtube.com/vi/fDqyw9JG6PY/sddefault.jpg)](https://youtu.be/fDqyw9JG6PY) |

## Screenshots

@@ -34,6 +34,10 @@ Google Colab demo: [![Open In Colab](https://colab.research.google.com/assets/co
https://rsxdalv.github.io/bark-speaker-directory/

## Changelog
Aug 4:
* Add MultiBandDiffusion option to MusicGen https://github.com/rsxdalv/tts-generation-webui/pull/109
* MusicGen/AudioGen save tokens on generation as .npz files.

Aug 3:
* Add AudioGen https://github.com/rsxdalv/tts-generation-webui/pull/105

src/bark/npz_tools.py (18 changes: 15 additions & 3 deletions)
@@ -3,6 +3,7 @@
import numpy as np
from src.bark.FullGeneration import FullGeneration
import json
import torch


def compress_history(full_generation: FullGeneration):
@@ -13,10 +14,11 @@ def compress_history(full_generation: FullGeneration):
    }


def save_npz(filename: str, full_generation: FullGeneration, metadata: dict[str, Any]):
    def pack_metadata(metadata: dict[str, Any]):
        return list(json.dumps(metadata))
def pack_metadata(metadata: dict[str, Any]):
    return list(json.dumps(metadata))


def save_npz(filename: str, full_generation: FullGeneration, metadata: dict[str, Any]):
    np.savez(
        filename,
        **{
@@ -26,6 +28,16 @@ def pack_metadata(metadata: dict[str, Any]):
    )


def save_npz_musicgen(filename: str, tokens: torch.Tensor, metadata: dict[str, Any]):
    np.savez(
        filename,
        **{
            "tokens": tokens.cpu().numpy(),
            "metadata": pack_metadata(metadata),
        },
    )


def load_npz(filename):
    def unpack_metadata(metadata: np.ndarray):
        def join_list(x: list | np.ndarray):
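A note on the new .npz format: save_npz_musicgen stores the raw MusicGen tokens plus the generation metadata packed as a list of JSON characters (see pack_metadata above). Below is a minimal sketch of reading such a file back, assuming only the packing shown in this diff; the helper name load_npz_musicgen is illustrative and not part of the PR.

```python
# Minimal sketch (illustrative, not part of this PR) of reading a .npz written
# by save_npz_musicgen. Assumes metadata was packed with pack_metadata, i.e.
# stored as a list of single-character strings of a JSON dump.
import json
import numpy as np

def load_npz_musicgen(filename: str):
    with np.load(filename, allow_pickle=True) as data:
        tokens = data["tokens"]  # array saved via tokens.cpu().numpy()
        metadata = json.loads("".join(data["metadata"]))  # reverse of pack_metadata
    return tokens, metadata
```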
src/musicgen/musicgen_tab.py (60 changes: 51 additions & 9 deletions)
@@ -5,6 +5,7 @@
from typing import Optional, Tuple, TypedDict
import numpy as np
import os
from src.bark.npz_tools import save_npz_musicgen
from src.musicgen.setup_seed_ui_musicgen import setup_seed_ui_musicgen
from src.bark.parse_or_set_seed import parse_or_set_seed
from src.musicgen.audio_array_to_sha256 import audio_array_to_sha256
@@ -37,6 +38,7 @@ class MusicGenGeneration(TypedDict):
    temperature: float
    cfg_coef: float
    seed: int
    use_multi_band_diffusion: bool


def melody_to_sha256(melody: Optional[Tuple[int, np.ndarray]]) -> Optional[str]:
@@ -76,13 +78,14 @@ def save_generation(
    audio_array: np.ndarray,
    SAMPLE_RATE: int,
    params: MusicGenGeneration,
    tokens: torch.Tensor,
):
    prompt = params["text"]
    date = get_date_string()
    title = prompt[:20].replace(" ", "_")
    base_filename = create_base_filename(title, "outputs", model="musicgen", date=date)

    filename, filename_png, filename_json, _ = get_filenames(base_filename)
    filename, filename_png, filename_json, filename_npz = get_filenames(base_filename)
    write_wav(filename, SAMPLE_RATE, audio_array)
    plot = save_waveform_plot(audio_array, filename_png)

@@ -93,6 +96,7 @@ def save_generation(
        params=params,
        audio_array=audio_array,
    )
    save_npz_musicgen(filename_npz, tokens, metadata)

    filename_ogg = filename.replace(".wav", ".ogg")
    ext_callback_save_generation_musicgen(
@@ -165,17 +169,19 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr
        if melody.dim() == 2:
            melody = melody[None]
        melody = melody[..., : int(sr * MODEL.lm.cfg.dataset.segment_duration)]  # type: ignore
        output = MODEL.generate_with_chroma(
        output, tokens = MODEL.generate_with_chroma(
            descriptions=[text],
            melody_wavs=melody,
            melody_sample_rate=sr,
            progress=False,
            return_tokens=True,
            # generator=generator,
        )
    else:
        output = MODEL.generate(
        output, tokens = MODEL.generate(
            descriptions=[text],
            progress=True,
            return_tokens=True,
            # generator=generator,
        )
    set_seed(-1)
@@ -184,12 +190,19 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr
    # print time taken
    print("Generated in", "{:.3f}".format(elapsed), "seconds")

    output = output.detach().cpu().numpy().squeeze()
    if params["use_multi_band_diffusion"]:
        from audiocraft.models.multibanddiffusion import MultiBandDiffusion
        mbd = MultiBandDiffusion.get_mbd_musicgen()
        wav_diffusion = mbd.tokens_to_wav(tokens)
        output = wav_diffusion.detach().cpu().numpy().squeeze()
    else:
        output = output.detach().cpu().numpy().squeeze()

    filename, plot, _metadata = save_generation(
        audio_array=output,
        SAMPLE_RATE=MODEL.sample_rate,
        params=params,
        tokens=tokens,
    )

    return [
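For review context, here is the new decode path in isolation: generate with return_tokens=True, then hand the tokens to Multi-Band Diffusion instead of the default EnCodec decoder. This is a sketch under the audiocraft APIs used in the hunk above; the model id and prompt are placeholders.

```python
# Standalone sketch of the Multi-Band Diffusion decode path added above.
# Assumes an audiocraft version that supports return_tokens and MBD;
# "facebook/musicgen-small" and the prompt are illustrative.
from audiocraft.models import MusicGen
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

model = MusicGen.get_pretrained("facebook/musicgen-small")
model.set_generation_params(duration=8)

# return_tokens=True yields both the EnCodec-decoded audio and the raw tokens.
wav_encodec, tokens = model.generate(["lo-fi drum loop"], return_tokens=True)

# Decode the same tokens with Multi-Band Diffusion instead of EnCodec's decoder.
mbd = MultiBandDiffusion.get_mbd_musicgen()
wav_mbd = mbd.tokens_to_wav(tokens)

print(wav_encodec.shape, wav_mbd.shape, model.sample_rate)
```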
@@ -215,19 +228,31 @@ def generation_tab_musicgen():
                "temperature": 1.0,
                "cfg_coef": 3.0,
                "seed": -1,
                "use_multi_band_diffusion": False,
            },
        )
        # musicgen_atom.render()
        gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
        with gr.Row():
        with gr.Row(equal_height=False):
            with gr.Column():
                text = gr.Textbox(
                    label="Prompt", lines=3, placeholder="Enter text here..."
                )
                model = gr.Radio(
                    ["melody", "medium", "small", "large", "facebook/audiogen-medium"],
                    [
                        "facebook/musicgen-melody",
                        # "musicgen-melody",
                        "facebook/musicgen-medium",
                        # "musicgen-medium",
                        "facebook/musicgen-small",
                        # "musicgen-small",
                        "facebook/musicgen-large",
                        # "musicgen-large",
                        "facebook/audiogen-medium",
                        # "audiogen-medium",
                    ],
                    label="Model",
                    value="melody",
                    value="facebook/musicgen-small",
                )
                melody = gr.Audio(
                    source="upload",
@@ -269,6 +294,10 @@ def generation_tab_musicgen():
                    interactive=True,
                    step=0.1,
                )
                use_multi_band_diffusion = gr.Checkbox(
                    label="Use Multi-Band Diffusion",
                    value=False,
                )
                seed, set_old_seed_button, _ = setup_seed_ui_musicgen()

            with gr.Column():
@@ -295,7 +324,18 @@ def generation_tab_musicgen():
            outputs=[melody],
        )

        inputs = [text, melody, model, duration, topk, topp, temperature, cfg_coef, seed]
        inputs = [
            text,
            melody,
            model,
            duration,
            topk,
            topp,
            temperature,
            cfg_coef,
            seed,
            use_multi_band_diffusion,
        ]

        def update_components(x):
            return {
@@ -308,6 +348,7 @@ def update_components(x):
                temperature: x["temperature"],
                cfg_coef: x["cfg_coef"],
                seed: x["seed"],
                use_multi_band_diffusion: x["use_multi_band_diffusion"],
            }

        musicgen_atom.change(
@@ -317,7 +358,7 @@ def update_components(x):
        )

        def update_json(
            text, _melody, model, duration, topk, topp, temperature, cfg_coef, seed
            text, _melody, model, duration, topk, topp, temperature, cfg_coef, seed, use_multi_band_diffusion
        ):
            return {
                "text": text,
@@ -329,6 +370,7 @@ def update_json(
                "temperature": float(temperature),
                "cfg_coef": float(cfg_coef),
                "seed": int(seed),
                "use_multi_band_diffusion": bool(use_multi_band_diffusion),
            }

        seed_cache = gr.State()  # type: ignore
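Overall, the UI changes follow the tab's existing pattern: the new option exists both as a Gradio component and as a key in the musicgen_atom JSON, so it is added to the component definitions, the inputs list, update_json, and update_components. Below is a condensed, illustrative Gradio sketch of how a checkbox value reaches the generation function through the inputs list; the real tab additionally routes values through the JSON atom, and all names in the sketch are examples rather than code from this PR.

```python
# Condensed, illustrative sketch: the checkbox is appended to the inputs list,
# so its value arrives as one more handler argument, which in the real code
# becomes params["use_multi_band_diffusion"]. Names here are examples only.
import gradio as gr

def generate(text: str, use_multi_band_diffusion: bool) -> str:
    decoder = "MultiBandDiffusion" if use_multi_band_diffusion else "EnCodec"
    return f"Would generate {text!r} and decode with {decoder}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="Prompt")
    use_mbd = gr.Checkbox(label="Use Multi-Band Diffusion", value=False)
    result = gr.Textbox(label="Result")
    submit = gr.Button("Generate")
    submit.click(fn=generate, inputs=[text, use_mbd], outputs=[result])

# demo.launch()
```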