
Commit

semantic token ids will have variable lengths because of unique consecutive, so eos token must be manually selected and then used to predict the first coarse token, in the coarse transformer
lucidrains committed Nov 17, 2022
1 parent c6bcd11 commit 70f02c5
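The gist of the change: when unique consecutive is enabled, repeated semantic ids are collapsed per sequence, so rows of a batch end up with different lengths and the eos position varies per row. A minimal illustration of that behavior in plain torch (toy values; this snippet is a sketch, not the repository's helper):

import torch
import torch.nn.functional as F

# two rows of equal length, with different amounts of repetition
t = torch.tensor([
    [5, 5, 5, 2, 2, 7],
    [1, 3, 1, 3, 1, 3]
])

rows = [torch.unique_consecutive(row) for row in t.unbind(dim = 0)]
# rows[0] -> tensor([5, 2, 7]), length 3; rows[1] is unchanged, length 6

# right-pad to a common length so the rows can be stacked into a batch again
max_len = max(row.shape[-1] for row in rows)
padded = torch.stack([F.pad(row, (0, max_len - row.shape[-1]), value = -1) for row in rows])
# the eos (wherever it was appended) now sits at a different position per row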
Showing 2 changed files with 88 additions and 74 deletions.
160 changes: 87 additions & 73 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -361,6 +361,7 @@ def generate(
         cond_scale = 3,
         filter_thres = 0.9,
         temperature = 1.,
+        include_eos_in_output = True,  # if doing hierarchical sampling, eos must be kept for an easy time
         **kwargs
     ):
         device = self.device
@@ -422,7 +423,7 @@ def generate(

         last_logit_indices += 1

-        output = mask_out_after_eos_id(output, self.pad_id, include_eos = False)
+        output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output)
         return output

     def forward_with_cond_scale(
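The new include_eos_in_output flag is threaded through to mask_out_after_eos_id, which pads out everything past the eos in the generated ids. A plausible sketch of the masking logic it presumably implements (the real helper's exact signature may differ; here the eos id is taken explicitly):

import torch

def mask_out_after_eos_id(t, pad_id, eos_id, include_eos = True):
    # sketch: replace every position past the first eos with pad_id;
    # include_eos = True keeps the eos token itself in the output
    is_eos = t == eos_id
    after = is_eos.cumsum(dim = -1)
    if include_eos:
        after = after - is_eos.long()  # the eos position stays at zero, so it survives
    return t.masked_fill(after > 0, pad_id)

Keeping the eos matters here because the coarse transformer needs to see where the semantic sequence ended in order to predict its first coarse token.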
@@ -612,7 +613,17 @@ def forward(

         tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)

-        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]
+        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]
+
+        # get the eos token from predicted semantic tokens, and use that to predict the first coarse token
+
+        semantic_eos = semantic_token_ids == self.semantic_eos_id
+        pred_semantic_eos_tokens = tokens[:, 1:(semantic_seq_len + 1)][semantic_eos]
+
+        pred_coarse_tokens = torch.cat((
+            rearrange(pred_semantic_eos_tokens, 'b d -> b 1 d'),
+            pred_coarse_tokens),
+            dim = 1)

         # semantic logits
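Since the eos lands at a different position in every row, the hidden state sitting on the eos is gathered with a boolean mask rather than a fixed slice, then prepended so it predicts the first coarse token. A shape-level illustration with toy sizes, assuming exactly one eos per row (which the training wrapper guarantees by appending it); note that in the diff above the hidden states are additionally offset by one position, as usual for next-token prediction:

import torch
from einops import rearrange

b, n, d = 2, 5, 8
eos_id = 9

tokens = torch.randn(b, n, d)                     # transformer outputs over the semantic positions
semantic_token_ids = torch.randint(0, 9, (b, n))  # ids in [0, 9), so eos appears only where placed
semantic_token_ids[0, 3] = eos_id                 # eos at position 3 in row 0
semantic_token_ids[1, 4] = eos_id                 # eos at position 4 in row 1

semantic_eos = semantic_token_ids == eos_id       # (b, n) boolean mask, one True per row
eos_hiddens = tokens[semantic_eos]                # advanced indexing -> (b, d)
eos_hiddens = rearrange(eos_hiddens, 'b d -> b 1 d')  # ready to lead the coarse predictions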

@@ -789,74 +800,6 @@ def forward(

 # training wrappers

-class FineTransformerWrapper(nn.Module):
-    def __init__(
-        self,
-        *,
-        transformer: FineTransformer,
-        soundstream: Optional[SoundStream] = None,
-        num_coarse_quantize = 3
-    ):
-        super().__init__()
-        self.soundstream = soundstream
-        self.transformer = transformer
-
-        assert num_coarse_quantize > 0
-        self.num_coarse_quantize = num_coarse_quantize
-
-    def forward(
-        self,
-        *,
-        raw_wave = None,
-        coarse_token_ids = None,
-        fine_token_ids = None,
-        return_loss = False,
-        **kwargs
-    ):
-        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
-
-        if exists(raw_wave):
-            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
-
-            with torch.no_grad():
-                self.soundstream.eval()
-                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
-                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]
-
-        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
-        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')
-
-        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
-        fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)
-
-        if return_loss:
-            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
-            fine_token_ids = fine_token_ids[:, :-1]
-
-        coarse_logits, fine_logits = self.transformer(
-            coarse_token_ids = coarse_token_ids,
-            fine_token_ids = fine_token_ids,
-            **kwargs
-        )
-
-        if not return_loss:
-            return coarse_logits, fine_logits
-
-        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))
-
-        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]
-
-        coarse_loss = F.cross_entropy(
-            coarse_logits,
-            coarse_labels
-        )
-
-        fine_loss = F.cross_entropy(
-            fine_logits,
-            fine_labels
-        )
-
-        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)

 class CoarseTransformerWrapper(nn.Module):
     def __init__(
@@ -905,11 +848,12 @@ def forward(
                 _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                 coarse_token_ids, _ = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]

-        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
         semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
+        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')

-        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)
-        semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.semantic_eos_id)
+        if self.training:
+            semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.semantic_eos_id)
+            coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)

         if self.unique_consecutive:
             semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value = self.pad_id)
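append_eos_id is now applied only when self.training, since at inference the generate path appends and tracks the eos itself. A plausible sketch of the helper, assuming it simply appends one eos column per row (an illustration, not the repository's exact code):

import torch

def append_eos_id(ids, eos_id):
    # sketch: append a single eos id to every (batch, seq) row of token ids
    batch = ids.shape[0]
    eos = torch.full((batch, 1), eos_id, device = ids.device, dtype = ids.dtype)
    return torch.cat((ids, eos), dim = -1)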
@@ -954,6 +898,76 @@ def forward(

         return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits)

+class FineTransformerWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer: FineTransformer,
+        soundstream: Optional[SoundStream] = None,
+        num_coarse_quantize = 3
+    ):
+        super().__init__()
+        self.soundstream = soundstream
+        self.transformer = transformer
+
+        assert num_coarse_quantize > 0
+        self.num_coarse_quantize = num_coarse_quantize
+
+    def forward(
+        self,
+        *,
+        raw_wave = None,
+        coarse_token_ids = None,
+        fine_token_ids = None,
+        return_loss = False,
+        **kwargs
+    ):
+        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
+
+        if exists(raw_wave):
+            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
+
+            with torch.no_grad():
+                self.soundstream.eval()
+                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
+                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]
+
+        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
+        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')
+
+        if self.training:
+            coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
+            fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)
+
+        if return_loss:
+            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
+            fine_token_ids = fine_token_ids[:, :-1]
+
+        coarse_logits, fine_logits = self.transformer(
+            coarse_token_ids = coarse_token_ids,
+            fine_token_ids = fine_token_ids,
+            **kwargs
+        )
+
+        if not return_loss:
+            return coarse_logits, fine_logits
+
+        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))
+
+        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]
+
+        coarse_loss = F.cross_entropy(
+            coarse_logits,
+            coarse_labels
+        )
+
+        fine_loss = F.cross_entropy(
+            fine_logits,
+            fine_labels
+        )
+
+        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)
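A note on the combined loss: weighting each stream's cross entropy by its number of positions and dividing by the total recovers a plain per-token average over both streams, because F.cross_entropy with its default 'mean' reduction already averages within a stream. A quick numeric check with assumed toy shapes:

import torch
import torch.nn.functional as F
from einops import rearrange

coarse_logits = torch.randn(2, 7, 1024)   # (batch, positions, codebook size)
fine_logits = torch.randn(2, 13, 1024)
coarse_labels = torch.randint(0, 1024, (2, 7))
fine_labels = torch.randint(0, 1024, (2, 13))

cl = F.cross_entropy(rearrange(coarse_logits, 'b n c -> b c n'), coarse_labels)
fl = F.cross_entropy(rearrange(fine_logits, 'b n c -> b c n'), fine_labels)
weighted = (cl * 7 + fl * 13) / (7 + 13)

# pooled cross entropy over both streams at once gives the same number
pooled = F.cross_entropy(
    torch.cat((rearrange(coarse_logits, 'b n c -> b c n'), rearrange(fine_logits, 'b n c -> b c n')), dim = -1),
    torch.cat((coarse_labels, fine_labels), dim = -1)
)
assert torch.allclose(weighted, pooled, atol = 1e-6)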

 # audio LM

 class AudioLM(nn.Module):
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'audiolm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.31',
+  version = '0.0.32',
   license='MIT',
   description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
   author = 'Phil Wang',
