add style transfer; Update pip package
haoheliu committed Feb 15, 2023
1 parent 2a5abdf commit 4aaa6e6
Showing 12 changed files with 247 additions and 36 deletions.
19 changes: 16 additions & 3 deletions README.md
@@ -13,6 +13,10 @@ Generate speech, sound effects, music and beyond.
2. Try different random seeds; they can sometimes affect the generation quality significantly.
3. It's best to use general terms like 'man' or 'woman' instead of specific names for individuals, or abstract objects that people may not be familiar with, such as 'mummy'.

# Change Log

**2023-02-15**: Add audio style transfer. Add more generation options.

## Web APP
1. Prepare running environment
```shell
@@ -38,10 +42,19 @@ pip3 install audioldm==0.0.6
2. text-to-audio generation
```shell
# Test run
audioldm -t "A hammer is hitting a wooden surface"
audioldm -t "A hammer is hitting a wooden surface" # The default --mode is "generation"
```

3. audio-to-audio style transfer
```shell
# Test run
# --file_path is the original audio file to be transferred
# -t is the text prompt AudioLDM uses to guide the transfer
# Please make sure that --file_path exists
audioldm --mode "transfer" --file_path trumpet.wav -t "Children Singing"
```

For more options on guidance scale, batchsize, seed, etc, please run
For more options on guidance scale, batchsize, seed, ddim steps, etc., please run
```shell
audioldm -h
```
@@ -59,7 +72,7 @@ Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using
- [ ] Add AudioCaps finetuned AudioLDM-S model
- [x] Build pip installable package for commandline use
- [x] Build Gradio web application
- [ ] Add text-guided style transfer
- [x] Add text-guided style transfer
- [ ] Add audio super-resolution
- [ ] Add audio inpainting

2 changes: 1 addition & 1 deletion audioldm/__init__.py
@@ -1,5 +1,5 @@
from .ldm import LatentDiffusion
from .utils import seed_everything, save_wave
from .utils import seed_everything, save_wave, get_time
from .pipeline import *

import os
2 changes: 2 additions & 0 deletions audioldm/audio/__init__.py
@@ -0,0 +1,2 @@
from .tools import wav_to_fbank
from .stft import TacotronSTFT
73 changes: 57 additions & 16 deletions audioldm/audio/tools.py
@@ -1,6 +1,6 @@
import torch
import numpy as np

import torchaudio

def get_mel_from_wav(audio, _stft):
audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
@@ -13,21 +13,62 @@ def get_mel_from_wav(audio, _stft):
energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
return melspec, log_magnitudes_stft, energy

def _pad_spec(fbank, target_length=1024):
n_frames = fbank.shape[0]
p = target_length - n_frames
# cut and pad
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
fbank = m(fbank)
elif p < 0:
fbank = fbank[0:target_length, :]

if(fbank.size(-1) % 2 != 0):
fbank = fbank[..., :-1]

return fbank

def pad_wav(waveform, segment_length):
waveform_length = waveform.shape[-1]
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
if(waveform_length == segment_length):
return waveform
elif(waveform_length > segment_length):
return waveform[:, :segment_length]  # trim along the sample axis, not the channel axis
elif(waveform_length < segment_length):
temp_wav = np.zeros((1, segment_length))
temp_wav[:, :waveform_length] = waveform
return temp_wav

# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
# mel = torch.stack([mel])
# mel_decompress = _stft.spectral_de_normalize(mel)
# mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
# spec_from_mel_scaling = 1000
# spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
# spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
# spec_from_mel = spec_from_mel * spec_from_mel_scaling
def normalize_wav(waveform):
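# Remove the DC offset, then peak-normalize to a maximum absolute amplitude of 0.5.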
waveform = waveform - np.mean(waveform)
waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
return waveform * 0.5

def read_wav_file(filename, segment_length):
# waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
waveform, sr = torchaudio.load(filename) # Faster!!!
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
waveform = waveform.numpy()[0,...]
waveform = normalize_wav(waveform)
waveform = waveform[None,...]
waveform = pad_wav(waveform, segment_length)
return waveform

def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
assert fn_STFT is not None

# mixup
waveform = read_wav_file(filename, target_length * 160) # hop size is 160
waveform = waveform[0,...]
waveform = torch.FloatTensor(waveform)

# audio = griffin_lim(
# torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
# )
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)

fbank = torch.FloatTensor(fbank.T)
log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)

fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(log_magnitudes_stft, target_length)

# audio = audio.squeeze()
# audio = audio.cpu().numpy()
# audio_path = out_filename
# write(audio_path, _stft.sampling_rate, audio)
return fbank, log_magnitudes_stft, waveform
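
A minimal sketch of how these helpers fit together, assuming the audioldm package is installed and a short source file such as `trumpet.wav` is on hand; the `TacotronSTFT` arguments mirror the default preprocessing config added to `audioldm/utils.py` below, and the shapes noted in the comments are inferred from the code above rather than documented.

```python
from audioldm.audio import TacotronSTFT, wav_to_fbank

# STFT front end matching the default preprocessing config:
# filter_length=1024, hop_length=160, win_length=1024,
# 64 mel bins, 16 kHz sampling rate, fmin=0, fmax=8000.
fn_stft = TacotronSTFT(1024, 160, 1024, 64, 16000, 0, 8000)

# 1024 target frames * 160-sample hop = 163840 samples (~10.24 s at 16 kHz).
fbank, log_magnitudes, waveform = wav_to_fbank(
    "trumpet.wav", target_length=1024, fn_STFT=fn_stft
)
print(fbank.shape)  # roughly (1024, 64): padded mel frames x mel channels
```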

76 changes: 75 additions & 1 deletion audioldm/pipeline.py
@@ -3,9 +3,14 @@
import argparse
import yaml
import torch
from torch import autocast
from tqdm import tqdm, trange

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config
from audioldm.audio import wav_to_fbank, TacotronSTFT
from audioldm.latent_diffusion.ddim import DDIMSampler
from einops import repeat

import time

@@ -71,11 +76,12 @@ def text_to_audio(
latent_diffusion,
text,
seed=42,
ddim_steps=200,
duration=10,
batchsize=1,
guidance_scale=2.5,
n_candidate_gen_per_text=3,
config=None,
config=None
):
seed_everything(int(seed))
batch = make_batch_for_text_to_audio(text, batchsize=batchsize)
@@ -85,7 +91,75 @@ waveform = latent_diffusion.generate_sample(
waveform = latent_diffusion.generate_sample(
[batch],
unconditional_guidance_scale=guidance_scale,
ddim_steps=ddim_steps,
n_candidate_gen_per_text=n_candidate_gen_per_text,
duration=duration,
)
return waveform

def style_transfer(
latent_diffusion,
text,
original_audio_file_path,
transfer_strength,
seed=42,
duration=10,
batchsize=1,
guidance_scale=2.5,
ddim_steps=200,
config=None
):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")

if config is not None:
assert type(config) is str
config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
else:
config = default_audioldm_config()

seed_everything(int(seed))
latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
latent_diffusion.cond_stage_model.embed_mode = "text"

fn_STFT = TacotronSTFT(
config["preprocessing"]["stft"]["filter_length"],
config["preprocessing"]["stft"]["hop_length"],
config["preprocessing"]["stft"]["win_length"],
config["preprocessing"]["mel"]["n_mel_channels"],
config["preprocessing"]["audio"]["sampling_rate"],
config["preprocessing"]["mel"]["mel_fmin"],
config["preprocessing"]["mel"]["mel_fmax"],
)

mel, _, _ = wav_to_fbank(original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
mel = mel.unsqueeze(0).unsqueeze(0).to(device)
mel = repeat(mel, '1 ... -> b ...', b=batchsize)
init_latent = latent_diffusion.get_first_stage_encoding(latent_diffusion.encode_first_stage(mel)) # move to latent space, encode and sample

sampler = DDIMSampler(latent_diffusion)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)

t_enc = int(transfer_strength * ddim_steps)
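# t_enc is the number of DDIM steps of noise added to the source latent below;
# a larger transfer_strength therefore keeps less of the original audio and
# follows the text prompt more closely.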
prompts = text

with torch.no_grad():
with autocast("cuda"):
with latent_diffusion.ema_scope():
uc = None
if guidance_scale != 1.0:
uc = latent_diffusion.cond_stage_model.get_unconditional_condition(batchsize)

c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)

z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batchsize).to(device))

samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=guidance_scale, unconditional_conditioning=uc)

x_samples = latent_diffusion.decode_first_stage(samples)

waveform = latent_diffusion.first_stage_model.decode_to_waveform(x_samples)

return waveform
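
One way the new `style_transfer` function might be called from Python rather than the CLI. This is an illustrative sketch: the `build_model()` loader is an assumption (it is not among the lines shown here), and the final reshape only guards against the vocoder returning `(batch, samples)` since `save_wave` in `audioldm/utils.py` indexes `waveform[i, 0]`.

```python
import os
from audioldm import build_model, save_wave  # build_model is assumed to load the pretrained model
from audioldm.pipeline import style_transfer

latent_diffusion = build_model()  # assumed helper; returns a LatentDiffusion instance

# transfer_strength (0-1) is the fraction of the DDIM steps used to noise the
# source latent: higher values keep less of trumpet.wav and follow the prompt more.
waveform = style_transfer(
    latent_diffusion,
    text="Children Singing",
    original_audio_file_path="trumpet.wav",
    transfer_strength=0.5,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    ddim_steps=200,
)

# save_wave expects (batch, channel, samples); add a channel axis if needed.
if waveform.ndim == 2:
    waveform = waveform[:, None, :]

os.makedirs("output", exist_ok=True)
save_wave(waveform, "output", name="children_singing_transfer")
```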
27 changes: 27 additions & 0 deletions audioldm/utils.py
@@ -4,7 +4,11 @@

import os
import soundfile as sf
import time

def get_time():
t = time.localtime()
return time.strftime("%d_%m_%Y_%H_%M_%S", t)

def seed_everything(seed):
import random, os
@@ -35,6 +39,7 @@ def save_wave(waveform, savepath, name="outwav"):
i,
),
)
print("Save audio to %s" % path)
sf.write(path, waveform[i, 0], samplerate=16000)


@@ -81,6 +86,28 @@ def default_audioldm_config():
"name": "default",
"root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
},
"preprocessing": {
"audio": {
"sampling_rate": 16000,
"max_wav_value": 32768
},
"stft": {
"filter_length": 1024,
"hop_length": 160,
"win_length": 1024
},
"mel": {
"n_mel_channels": 64,
"mel_fmin": 0,
"mel_fmax": 8000,
"freqm": 0,
"timem": 0,
"blur": False,
"mean": -4.63,
"std": 2.74,
"target_length": 1024
}
},
"model": {
"device": "cuda",
"target": "audioldm.pipline.LatentDiffusion",
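
As a sanity check on the defaults above, a short sketch of the frame/sample arithmetic they imply; the mapping is inferred from this diff (`read_wav_file` loads `target_length * 160` samples and `style_transfer` requests `int(duration * 102.4)` frames), not from documentation.

```python
# Values copied from default_audioldm_config()["preprocessing"].
sampling_rate = 16000
hop_length = 160
target_length = 1024  # mel frames

frames_per_second = sampling_rate / hop_length      # 100.0 frames/s
segment_samples = target_length * hop_length        # 163840 samples
segment_seconds = segment_samples / sampling_rate   # 10.24 s

# style_transfer() asks for int(duration * 102.4) frames, i.e. 1024 frames for
# the default duration=10, which matches target_length and explains the
# target_length * 160 factor in read_wav_file().
print(frames_per_second, segment_samples, segment_seconds)
```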