
Commit

add cross attention layers as well as setup t5 and some conditioning logic, for helping researchers explore TTS in this setting
lucidrains committed Nov 11, 2022
1 parent 26dfc80 commit c17ee7d
Showing 4 changed files with 236 additions and 25 deletions.
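
In short, each of the hierarchical transformers (semantic, coarse, fine) gains an optional text-conditioning path: T5 text embeddings are cross attended to whenever `has_condition = True`. The diffs below carry the details; as a rough orientation, here is a hypothetical usage sketch of the conditioned semantic stage (class name `SemanticTransformer` taken from the repository, hyperparameters and the example text made up for illustration):

```python
import torch
from audiolm_pytorch.audiolm_pytorch import SemanticTransformer

# illustrative sizes only -- not values from this commit
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = 500,
    dim = 512,
    depth = 6,
    has_condition = True   # switches on the new T5 + cross attention conditioning
)

semantic_token_ids = torch.randint(0, 500, (1, 1024))   # e.g. precomputed vq-wav2vec / hubert ids

loss = semantic_transformer(
    ids = semantic_token_ids,
    text = ['a person reading a passage from a book'],
    return_loss = True
)
loss.backward()
```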
6 changes: 4 additions & 2 deletions README.md
@@ -48,7 +48,7 @@ loss.backward()

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research

- <a href="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate library
- <a href="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate and transformers libraries

- <a href="https://ai.facebook.com/">MetaAI</a> for <a href="https://github.com/facebookresearch/fairseq">Fairseq</a> and the liberal license

@@ -58,6 +58,7 @@ loss.backward()

- [x] complete CoarseTransformer
- [x] use fairseq vq-wav2vec for embeddings
- [x] add conditioning

- [ ] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [ ] complete full training code for soundstream, taking care of discriminator training
@@ -69,7 +70,8 @@ loss.backward()
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
- [ ] DRY a little at the end
- [ ] figure out how to suppress logging in fairseq
- [ ] test with speech synthesis for starters, add conditioning + classifier free guidance as well
- [ ] test with speech synthesis for starters
- [ ] add classifier free guidance

## Citations

149 changes: 127 additions & 22 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -14,11 +14,16 @@
from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def ceil_div(numer, denom):
return (numer + denom - 1) // denom

@@ -471,24 +476,53 @@ class Attention(nn.Module):
def __init__(
self,
dim,
causal = False,
dim_head = 64,
heads = 8
dim_context = None,
heads = 8,
norm_context = False,
num_null_kv = 0
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
inner_dim = dim_head * heads

dim_context = default(dim_context, dim)

self.norm = nn.LayerNorm(dim)
self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

self.num_null_kv = num_null_kv
self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(self, x, attn_bias = None):
def forward(
self,
x,
context = None,
mask = None,
attn_bias = None
):
b = x.shape[0]

if exists(context):
context = self.context_norm(context)

kv_input = default(context, x)

x = self.norm(x)

q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

if self.num_null_kv > 0:
null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
k = torch.cat((null_k, k), dim = -2)
v = torch.cat((null_v, v), dim = -2)

q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

@@ -497,11 +531,18 @@ def forward(self, x, attn_bias = None):
sim = einsum('b h i d, b j d -> b h i j', q, k)

if exists(attn_bias):
attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
sim = sim + attn_bias

i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
if exists(mask):
mask = F.pad(mask, (self.num_null_kv, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
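
Worth noting in the reworked attention: a learned null key/value is prepended to the context and the context mask is padded with `True` in that slot, so queries always have at least one valid position to attend to even if the whole text context is masked out (a property that should come in handy for the classifier free guidance still listed as a todo in the README). A toy sketch, with an assumed import path and made-up shapes:

```python
import torch
from audiolm_pytorch.audiolm_pytorch import Attention

# cross attention flavour: the context has its own width and its own LayerNorm
cross_attn = Attention(dim = 512, dim_context = 768, num_null_kv = 1, norm_context = True)

x = torch.randn(2, 1024, 512)             # decoder tokens
context = torch.randn(2, 32, 768)         # e.g. T5 text embeddings
empty_mask = torch.zeros(2, 32).bool()    # pretend every text position is padding

out = cross_attn(x, context = context, mask = empty_mask)   # attention falls back to the null kv
assert out.shape == (2, 1024, 512)
```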

@@ -518,6 +559,8 @@ def __init__(
*,
dim,
depth,
dim_context = None,
cross_attend = False,
**kwargs
):
super().__init__()
@@ -527,19 +570,31 @@ def __init__(

for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, **kwargs),
Attention(dim = dim, causal = True, **kwargs),
Attention(dim = dim, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
FeedForward(dim = dim)
]))

self.norm = nn.LayerNorm(dim)

def forward(self, x):
def forward(
self,
x,
context = None,
context_mask = None
):
n, device = x.shape[1], x.device

rel_pos_bias = self.rel_pos_bias(n, n, device = device)

for attn, ff in self.layers:
for attn, cross_attn, ff in self.layers:
x = attn(x, attn_bias = rel_pos_bias) + x

if exists(cross_attn):
assert exists(context)

x = cross_attn(x, context = context, mask = context_mask)

x = ff(x) + x

return self.norm(x)
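
One level up, passing `cross_attend = True` interleaves one of these cross attention layers (single null key/value, normed context) between every causal self attention and feedforward block, with the text embeddings threaded through as `context` / `context_mask`. A compact sketch under the same assumptions as above:

```python
import torch
from audiolm_pytorch.audiolm_pytorch import Transformer

transformer = Transformer(
    dim = 512,
    depth = 6,
    dim_context = 768,     # width of the conditioning embeddings, e.g. the T5 encoder dim
    cross_attend = True
)

tokens = torch.randn(2, 1024, 512)
text_embeds = torch.randn(2, 32, 768)
text_mask = torch.ones(2, 32).bool()

out = transformer(tokens, context = text_embeds, context_mask = text_mask)   # (2, 1024, 512)
```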
@@ -552,25 +607,42 @@ def __init__(
*,
num_semantic_tokens,
dim,
t5_name = DEFAULT_T5_NAME,
has_condition = False,
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
**kwargs
):
super().__init__()
self.has_condition = has_condition
self.embed_text = partial(t5_encode_text, name = t5_name)

self.start_token = nn.Parameter(torch.randn(dim))

self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)

self.wav2vec = wav2vec
self.transformer = Transformer(dim = dim, **kwargs)
self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
self.to_logits = nn.Linear(dim, num_semantic_tokens)

def forward(
self,
*,
raw_wave = None,
ids = None,
return_loss = False
return_loss = False,
text = None,
text_embed = None
):
device = next(self.parameters()).device

has_text = exists(text) or exists(text_embed)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

assert exists(raw_wave) ^ exists(ids)

if not exists(ids):
@@ -586,7 +658,7 @@ def forward(

tokens = torch.cat((start_tokens, tokens), dim = 1)

tokens = self.transformer(tokens)
tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
logits = self.to_logits(tokens)

if not return_loss:
@@ -607,17 +679,22 @@ def __init__(
codebook_size,
num_coarse_quantizers,
dim,
t5_name = DEFAULT_T5_NAME,
has_condition = False,
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
**kwargs
):
super().__init__()
self.has_condition = has_condition
self.embed_text = partial(t5_encode_text, name = t5_name)

self.start_token = nn.Parameter(torch.randn(dim))

self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)

self.wav2vec = wav2vec
self.transformer = Transformer(dim = dim, **kwargs)
self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)

self.codebook_size = codebook_size
self.num_coarse_quantizers = num_coarse_quantizers
@@ -630,9 +707,19 @@ def forward(
*,
semantic_token_ids,
coarse_token_ids,
text = None,
text_embed = None
):
b, device = semantic_token_ids.shape[0], semantic_token_ids.device

has_text = exists(text) or exists(text_embed)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
@@ -649,7 +736,7 @@ def forward(

tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1)

tokens = self.transformer(tokens)
tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)

pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]

@@ -689,16 +776,20 @@ def __init__(
num_fine_quantizers,
codebook_size,
dim,
t5_name = DEFAULT_T5_NAME,
has_condition = False,
**kwargs
):
super().__init__()
self.has_condition = has_condition
self.embed_text = partial(t5_encode_text, name = t5_name)

self.start_token = nn.Parameter(torch.randn(dim))

self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

self.transformer = Transformer(dim = dim, **kwargs)
self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)

self.codebook_size = codebook_size
self.num_coarse_quantizers = num_coarse_quantizers
@@ -710,10 +801,20 @@ def __init__(
def forward(
self,
coarse_token_ids,
fine_token_ids
fine_token_ids,
text = None,
text_embed = None
):
device = coarse_token_ids.device

has_text = exists(text) or exists(text_embed)
assert not (self.has_condition ^ has_text)

if not exists(text_embed):
with torch.no_grad():
text_embeds = self.embed_text(text, output_device = device)
text_mask = torch.any(text_embeds != 0, dim = -1)

coarse_token_ids, fine_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, fine_token_ids))

b, n = coarse_token_ids.shape
@@ -735,7 +836,7 @@ def forward(

tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1)

tokens = self.transformer(tokens)
tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)

pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:]

@@ -794,7 +895,8 @@ def forward(
raw_wave = None,
coarse_token_ids = None,
fine_token_ids = None,
return_loss = False
return_loss = False,
**kwargs
):
assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

@@ -815,7 +917,8 @@ def forward(

coarse_logits, fine_logits = self.transformer(
coarse_token_ids = coarse_token_ids,
fine_token_ids = fine_token_ids
fine_token_ids = fine_token_ids,
**kwargs
)

if not return_loss:
@@ -859,7 +962,8 @@ def forward(
semantic_token_ids = None,
raw_wave = None,
coarse_token_ids = None,
return_loss = False
return_loss = False,
**kwargs
):
assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
assert exists(raw_wave) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
@@ -886,7 +990,8 @@ def forward(

semantic_logits, coarse_logits = self.transformer(
semantic_token_ids = semantic_token_ids,
coarse_token_ids = coarse_token_ids
coarse_token_ids = coarse_token_ids,
**kwargs
)

if not return_loss:
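The diffs for the remaining changed files are collapsed above; among them is presumably the `audiolm_pytorch/t5.py` module imported at the top of this file, which supplies `t5_encode_text`, `get_encoded_dim` and `DEFAULT_T5_NAME`. Going only by how those helpers are called in the diff, here is a hedged sketch of precomputing text embeddings and the padding mask the transformers derive from them:

```python
import torch
from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

texts = ['a person reading a passage from a book']

# called the same way as in the diff: `name` picks the T5 checkpoint,
# `output_device` is where the encoded text should end up
text_embeds = t5_encode_text(texts, name = DEFAULT_T5_NAME, output_device = torch.device('cpu'))

# the transformers treat all-zero embedding vectors as padding when building the context mask
text_mask = torch.any(text_embeds != 0, dim = -1)

assert text_embeds.shape[-1] == get_encoded_dim(DEFAULT_T5_NAME)   # the width wired into dim_context
```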
