diff --git a/README.md b/README.md
index ea6aa11..bd4e523 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,10 @@ Generate speech, sound effects, music and beyond.
 2. Try to use different random seeds, which can affect the generation quality significantly sometimes.
 3. It's best to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.
 
+# Change Log
+
+**2023-02-15**: Add audio style transfer. Add more options for generation.
+
 ## Web APP
 1. Prepare running environment
 ```shell
@@ -38,10 +42,19 @@ pip3 install audioldm==0.0.6
 2. text-to-audio generation
 ```python
 # Test run
-audioldm -t "A hammer is hitting a wooden surface"
+audioldm -t "A hammer is hitting a wooden surface" # The default --mode is "generation"
+```
+
+3. audio-to-audio style transfer
+```python
+# Test run
+# --file_path is the original audio file for transfer
+# -t is the text AudioLDM uses for the transfer.
+# Please make sure that --file_path exists
+audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing"
 ```
 
-For more options on guidance scale, batchsize, seed, etc, please run
+For more options on guidance scale, batchsize, seed, ddim steps, etc., please run
 ```shell
 audioldm -h
 ```
@@ -59,7 +72,7 @@ Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using
 - [ ] Add AudioCaps finetuned AudioLDM-S model
 - [x] Build pip installable package for commandline use
 - [x] Build Gradio web application
-- [ ] Add text-guided style transfer
+- [x] Add text-guided style transfer
 - [ ] Add audio super-resolution
 - [ ] Add audio inpainting
 
diff --git a/audioldm/__init__.py b/audioldm/__init__.py
index b9f2387..37f2303 100644
--- a/audioldm/__init__.py
+++ b/audioldm/__init__.py
@@ -1,5 +1,5 @@
 from .ldm import LatentDiffusion
-from .utils import seed_everything, save_wave
+from .utils import seed_everything, save_wave, get_time
 from .pipeline import *
 
 import os
diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py
index e69de29..2ced654 100644
--- a/audioldm/audio/__init__.py
+++ b/audioldm/audio/__init__.py
@@ -0,0 +1,2 @@
+from .tools import wav_to_fbank
+from .stft import TacotronSTFT
\ No newline at end of file
diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py
index 7aca95c..ddf4c9c 100644
--- a/audioldm/audio/tools.py
+++ b/audioldm/audio/tools.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-
+import torchaudio
 
 def get_mel_from_wav(audio, _stft):
     audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
@@ -13,21 +13,62 @@ def get_mel_from_wav(audio, _stft):
     energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
     return melspec, log_magnitudes_stft, energy
 
+def _pad_spec(fbank, target_length=1024):
+    n_frames = fbank.shape[0]
+    p = target_length - n_frames
+    # cut and pad
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[0:target_length, :]
+
+    if(fbank.size(-1) % 2 != 0):
+        fbank = fbank[..., :-1]
+
+    return fbank
+
+def pad_wav(waveform, segment_length):
+    waveform_length = waveform.shape[-1]
+    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+    if(waveform_length == segment_length):
+        return waveform
+    elif(waveform_length > segment_length):
+        return waveform[: segment_length]
+    elif(waveform_length < segment_length):
+        temp_wav = np.zeros((1, segment_length))
+        temp_wav[:, :waveform_length] = waveform
+    return temp_wav
 
-# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
-#     mel = torch.stack([mel])
-#     mel_decompress = _stft.spectral_de_normalize(mel)
-#     mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
-#     spec_from_mel_scaling = 1000
-#     spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
-#     spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
-#     spec_from_mel = spec_from_mel * spec_from_mel_scaling
+def normalize_wav(waveform):
+    waveform = waveform - np.mean(waveform)
+    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    return waveform * 0.5
+
+def read_wav_file(filename, segment_length):
+    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+    waveform, sr = torchaudio.load(filename) # Faster!!!
+    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+    waveform = waveform.numpy()[0,...]
+    waveform = normalize_wav(waveform)
+    waveform = waveform[None,...]
+    waveform = pad_wav(waveform, segment_length)
+    return waveform
+
+def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
+    assert fn_STFT is not None
+
+    # mixup
+    waveform = read_wav_file(filename, target_length * 160) # hop size is 160
+    waveform = waveform[0,...]
+    waveform = torch.FloatTensor(waveform)
 
-#     audio = griffin_lim(
-#         torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
-#     )
+    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
+
+    fbank = torch.FloatTensor(fbank.T)
+    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
+
+    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(log_magnitudes_stft, target_length)
 
-#     audio = audio.squeeze()
-#     audio = audio.cpu().numpy()
-#     audio_path = out_filename
-#     write(audio_path, _stft.sampling_rate, audio)
+    return fbank, log_magnitudes_stft, waveform
+    
\ No newline at end of file
diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py
index 3a325b9..f5a31ae 100644
--- a/audioldm/pipeline.py
+++ b/audioldm/pipeline.py
@@ -3,9 +3,14 @@ import argparse
 import yaml
 import torch
+from torch import autocast
+from tqdm import tqdm, trange
 
 from audioldm import LatentDiffusion, seed_everything
 from audioldm.utils import default_audioldm_config
+from audioldm.audio import wav_to_fbank, TacotronSTFT
+from audioldm.latent_diffusion.ddim import DDIMSampler
+from einops import repeat
 
 import time
 
 
@@ -71,11 +76,12 @@ def text_to_audio(
     latent_diffusion,
     text,
     seed=42,
+    ddim_steps=200,
     duration=10,
     batchsize=1,
     guidance_scale=2.5,
     n_candidate_gen_per_text=3,
-    config=None,
+    config=None
 ):
     seed_everything(int(seed))
     batch = make_batch_for_text_to_audio(text, batchsize=batchsize)
@@ -85,7 +91,75 @@ def text_to_audio(
     waveform = latent_diffusion.generate_sample(
         [batch],
         unconditional_guidance_scale=guidance_scale,
+        ddim_steps=ddim_steps,
         n_candidate_gen_per_text=n_candidate_gen_per_text,
         duration=duration,
     )
     return waveform
+
+def style_transfer(
+    latent_diffusion,
+    text,
+    original_audio_file_path,
+    transfer_strength,
+    seed=42,
+    duration=10,
+    batchsize=1,
+    guidance_scale=2.5,
+    ddim_steps=200,
+    config=None
+):
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+    else:
+        device = torch.device("cpu")
+
+    if config is not None:
+        assert type(config) is str
+        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+    else:
+        config = default_audioldm_config()
+
+    seed_everything(int(seed))
+    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+    latent_diffusion.cond_stage_model.embed_mode = "text"
+
+    fn_STFT = TacotronSTFT(
+        config["preprocessing"]["stft"]["filter_length"],
+        config["preprocessing"]["stft"]["hop_length"],
+        config["preprocessing"]["stft"]["win_length"],
+        config["preprocessing"]["mel"]["n_mel_channels"],
+        config["preprocessing"]["audio"]["sampling_rate"],
+        config["preprocessing"]["mel"]["mel_fmin"],
+        config["preprocessing"]["mel"]["mel_fmax"],
+    )
+
+    mel, _, _ = wav_to_fbank(original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
+    mel = mel.unsqueeze(0).unsqueeze(0).to(device)
+    mel = repeat(mel, '1 ... -> b ...', b=batchsize)
+    init_latent = latent_diffusion.get_first_stage_encoding(latent_diffusion.encode_first_stage(mel)) # move to latent space, encode and sample
+
+    sampler = DDIMSampler(latent_diffusion)
+    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)
+
+    t_enc = int(transfer_strength * ddim_steps)
+    prompts = text
+
+    with torch.no_grad():
+        with autocast("cuda"):
+            with latent_diffusion.ema_scope():
+                uc = None
+                if guidance_scale != 1.0:
+                    uc = latent_diffusion.cond_stage_model.get_unconditional_condition(batchsize)
+
+                c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
+
+                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batchsize).to(device))
+
+                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=guidance_scale, unconditional_conditioning=uc)
+
+                x_samples = latent_diffusion.decode_first_stage(samples)
+
+                waveform = latent_diffusion.first_stage_model.decode_to_waveform(x_samples)
+
+    return waveform
diff --git a/audioldm/utils.py b/audioldm/utils.py
index 1ed0e62..fcc8516 100644
--- a/audioldm/utils.py
+++ b/audioldm/utils.py
@@ -4,7 +4,11 @@
 import os
 import soundfile as sf
+import time
 
+def get_time():
+    t = time.localtime()
+    return time.strftime("%d_%m_%Y_%H_%M_%S", t)
 
 def seed_everything(seed):
     import random, os
@@ -35,6 +39,7 @@ def save_wave(waveform, savepath, name="outwav"):
                 i,
             ),
         )
+        print("Save audio to %s" % path)
         sf.write(path, waveform[i, 0], samplerate=16000)
 
 
@@ -81,6 +86,28 @@ def default_audioldm_config():
             "name": "default",
             "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
         },
+        "preprocessing": {
+            "audio": {
+                "sampling_rate": 16000,
+                "max_wav_value": 32768
+            },
+            "stft": {
+                "filter_length": 1024,
+                "hop_length": 160,
+                "win_length": 1024
+            },
+            "mel": {
+                "n_mel_channels": 64,
+                "mel_fmin": 0,
+                "mel_fmax": 8000,
+                "freqm": 0,
+                "timem": 0,
+                "blur": False,
+                "mean": -4.63,
+                "std": 2.74,
+                "target_length": 1024
+            }
+        },
         "model": {
             "device": "cuda",
             "target": "audioldm.pipline.LatentDiffusion",
diff --git a/bin/audioldm b/bin/audioldm
index 61edee4..d576729 100644
--- a/bin/audioldm
+++ b/bin/audioldm
@@ -1,11 +1,20 @@
 #!/usr/bin/python3
 import os
-from audioldm import text_to_audio, build_model, save_wave
-
+from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time
+import time
 import argparse
 
 parser = argparse.ArgumentParser()
 
+parser.add_argument(
+    "--mode",
+    type=str,
+    required=False,
+    default="generation",
+    help="generation: text-to-audio generation; transfer: style transfer",
+    choices=["generation", "transfer"]
+)
+
 parser.add_argument(
     "-t",
     "--text",
@@ -15,6 +24,23 @@ parser.add_argument(
     help="Text prompt to the model for audio generation",
 )
 
+parser.add_argument(
+    "-f",
+    "--file_path",
+    type=str,
+    required=False,
+    default=None,
+    help="Original audio file for style transfer",
+)
+
+parser.add_argument(
+    "--transfer_strength",
+    type=float,
+    required=False,
+    default=0.5,
+    help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
+)
+
 parser.add_argument(
     "-s",
     "--save_path",
@@ -45,6 +71,15 @@ parser.add_argument(
     help="Generate how many samples at the same time",
 )
 
+parser.add_argument(
+    "--ddim_steps",
+    type=int,
+    required=False,
+    default=200,
+    help="The sampling step for DDIM",
+)
+
+
 parser.add_argument(
     "-gs",
     "--guidance_scale",
@@ -81,10 +116,11 @@ parser.add_argument(
 )
 
 args = parser.parse_args()
-
 assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
 
+save_path = os.path.join(args.save_path, args.mode)
+if(args.file_path is not None):
+    save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
-save_path = args.save_path
 text = args.text
 random_seed = args.seed
 duration = args.duration
@@ -93,14 +129,32 @@ n_candidate_gen_per_text = args.n_candidate_gen_per_text
 
 os.makedirs(save_path, exist_ok=True)
 audioldm = build_model(ckpt_path=args.ckpt_path)
 
-waveform = text_to_audio(
-    audioldm,
-    text,
-    random_seed,
-    duration=duration,
-    guidance_scale=guidance_scale,
-    n_candidate_gen_per_text=n_candidate_gen_per_text,
-    batchsize=args.batchsize,
-)
-save_wave(waveform, save_path, name=text)
+if(args.mode == "generation"):
+    waveform = text_to_audio(
+        audioldm,
+        text,
+        random_seed,
+        duration=duration,
+        guidance_scale=guidance_scale,
+        ddim_steps=args.ddim_steps,
+        n_candidate_gen_per_text=n_candidate_gen_per_text,
+        batchsize=args.batchsize,
+    )
+elif(args.mode == "transfer"):
+    assert args.file_path is not None
+    assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
+    waveform = style_transfer(
+        audioldm,
+        text,
+        args.file_path,
+        args.transfer_strength,
+        random_seed,
+        duration=duration,
+        guidance_scale=guidance_scale,
+        ddim_steps=args.ddim_steps,
+        batchsize=args.batchsize,
+    )
+    waveform = waveform[:,None,:]
+
+save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
diff --git a/setup.py b/setup.py
index 74d064c..18627e6 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@
 EMAIL = "haoheliu@gmail.com"
 AUTHOR = "Haohe Liu"
 REQUIRES_PYTHON = ">=3.7.0"
-VERSION = "0.0.6"
+VERSION = "0.0.7"
 
 # What packages are required for this module to be executed?
 REQUIRED = [
diff --git a/trumpet.wav b/trumpet.wav
new file mode 100644
index 0000000..efe4e29
Binary files /dev/null and b/trumpet.wav differ
diff --git a/trumpet/15_02_2023_11_29_36_Children Singing_0.wav b/trumpet/15_02_2023_11_29_36_Children Singing_0.wav
new file mode 100644
index 0000000..52737eb
Binary files /dev/null and b/trumpet/15_02_2023_11_29_36_Children Singing_0.wav differ
diff --git a/trumpet/15_02_2023_11_31_37_Children Singing_0.wav b/trumpet/15_02_2023_11_31_37_Children Singing_0.wav
new file mode 100644
index 0000000..3f8bc91
Binary files /dev/null and b/trumpet/15_02_2023_11_31_37_Children Singing_0.wav differ
diff --git a/trumpet/15_02_2023_11_33_27_Children Singing_0.wav b/trumpet/15_02_2023_11_33_27_Children Singing_0.wav
new file mode 100644
index 0000000..b98e8d5
Binary files /dev/null and b/trumpet/15_02_2023_11_33_27_Children Singing_0.wav differ