Skip to content

Commit

Permalink
make sure unconditional synthesis can still work, add ability to resa…
Browse files Browse the repository at this point in the history
…mple input wave on the fly given input sampling frequencies is supplied
  • Loading branch information
lucidrains committed Nov 15, 2022
1 parent 2725ae8 commit 0290273
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 19 deletions.
40 changes: 26 additions & 14 deletions audiolm_pytorch/audiolm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from torchaudio.functional import resample

# helper functions

def exists(val):
Expand Down Expand Up @@ -295,9 +297,12 @@ def __init__(
adversarial_loss_weight = 1.,
feature_loss_weight = 100,
quantize_dropout = True,
quantize_dropout_cutoff_index = 0
quantize_dropout_cutoff_index = 0,
target_sample_khz = 24000
):
super().__init__()
self.target_sample_khz = target_sample_khz # for resampling on the fly

self.single_channel = input_channels == 1
self.strides = strides

Expand Down Expand Up @@ -363,8 +368,12 @@ def forward(
return_encoded = False,
return_discr_loss = False,
return_discr_losses_separately = False,
return_recons_only = False
return_recons_only = False,
input_sample_khz = None
):
if exists(input_sample_khz):
x = resample(x, input_sample_khz, self.target_sample_khz)

if x.ndim == 2:
x = rearrange(x, 'b n -> b 1 n')

Expand Down Expand Up @@ -699,7 +708,7 @@ def forward(
ids = None,
return_loss = False,
text = None,
text_embed = None,
text_embeds = None,
cond_drop_prob = None
):
device = next(self.parameters()).device
Expand All @@ -717,17 +726,18 @@ def forward(
if self.unique_consecutive:
ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

has_text = exists(text) or exists(text_embed)
has_text = exists(text) or exists(text_embeds)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
text_mask = None
if not exists(text_embeds) and exists(text):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

if cond_drop_prob > 0:
if exists(text_mask) and cond_drop_prob > 0:
keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

Expand Down Expand Up @@ -798,22 +808,23 @@ def forward(
coarse_token_ids,
self_attn_mask = None,
text = None,
text_embed = None,
text_embeds = None,
cond_drop_prob = None
):
b, device = semantic_token_ids.shape[0], semantic_token_ids.device

has_text = exists(text) or exists(text_embed)
has_text = exists(text) or exists(text_embeds)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
text_mask = None
if not exists(text_embeds) and exists(text):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

if cond_drop_prob > 0:
if exists(text_mask) and cond_drop_prob > 0:
keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

Expand Down Expand Up @@ -906,21 +917,22 @@ def forward(
coarse_token_ids,
fine_token_ids,
text = None,
text_embed = None,
text_embeds = None,
cond_drop_prob = None
):
b, device = coarse_token_ids.shape[0], coarse_token_ids.device
has_text = exists(text) or exists(text_embed)
has_text = exists(text) or exists(text_embeds)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
text_mask = None
if not exists(text_embeds) and exists(text):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

if cond_drop_prob > 0:
if exists(text_mask) and cond_drop_prob > 0:
keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

Expand Down
20 changes: 18 additions & 2 deletions audiolm_pytorch/hubert_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
import joblib
import fairseq

from torchaudio.functional import resample

def exists(val):
return val is not None

class HubertWithKmeans(nn.Module):
def __init__(
self,
checkpoint_path,
kmeans_path
kmeans_path,
target_sample_khz = 50000
):
super().__init__()
self.target_sample_khz = target_sample_khz

model_path = Path(checkpoint_path)
kmeans_path = Path(kmeans_path)

Expand All @@ -39,9 +47,17 @@ def codebook_size(self):
return self.kmeans.n_clusters

@torch.no_grad()
def forward(self, wav_input, flatten = True):
def forward(
self,
wav_input,
flatten = True,
input_sample_khz = None
):
device = wav_input.device

if exists(input_sample_khz):
wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

embed = self.model(wav_input, features_only = True)
embed, packed_shape = pack([embed['x']], '* d')

Expand Down
20 changes: 18 additions & 2 deletions audiolm_pytorch/vq_wav2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@

import fairseq

from torchaudio.functional import resample

def exists(val):
return val is not None

class FairseqVQWav2Vec(nn.Module):
def __init__(
self,
checkpoint_path
checkpoint_path,
target_sample_khz = 24000
):
super().__init__()
self.target_sample_khz = target_sample_khz

path = Path(checkpoint_path)
assert path.exists(), f'path {checkpoint_path} does not exist'

Expand All @@ -31,7 +39,15 @@ def codebook_size(self):
return self.model.vector_quantizer.embedding.shape[0]

@torch.no_grad()
def forward(self, wav_input, flatten = True):
def forward(
self,
wav_input,
flatten = True,
input_sample_khz = None
):
if exists(input_sample_khz):
wav_input = resample(wav_input, input_sample_khz, self.target_sample_khz)

embed = self.model.feature_extractor(wav_input)
_, codebook_indices = self.model.vector_quantizer.forward_idx(embed)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'audiolm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.20',
version = '0.0.21',
license='MIT',
description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0290273

Please sign in to comment.