fix: various fixes and enhancements (metavoiceio#46)
* fix: various fixes and enhancements

* Update sample.py

Signed-off-by: Piotr Sokólski <piotr@themetavoice.xyz>

* Update serving.py

Signed-off-by: Piotr Sokólski <piotr@themetavoice.xyz>

* move requirements

---------

Signed-off-by: Piotr Sokólski <piotr@themetavoice.xyz>
Piotr Sokólski authored Feb 12, 2024
1 parent c82b062 commit 43f97a0
Showing 7 changed files with 168 additions and 74 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -25,6 +25,10 @@ sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
rm -rf ffmpeg-git-*

pip install -r requirements.txt

# Works only on the latest NVIDIA GPUs. If you have a different GPU, do not install this.
pip install flash-attn

pip install -e .
```
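As a rough guide to what "latest NVIDIA GPUs" means here: the flash-attn 2 kernels target Ampere-or-newer hardware (CUDA compute capability 8.0 and up). The snippet below is a minimal, hypothetical pre-install check, not part of this commit; older cards can simply skip the install and rely on the `torch_attn` kernel handled in `attn.py` below.

```python
import torch

# Hypothetical helper (assumption, not the repo's API): flash-attn 2 requires
# an Ampere-or-newer NVIDIA GPU, i.e. CUDA compute capability >= 8.0.
def flash_attn_likely_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

if __name__ == "__main__":
    print("install flash-attn" if flash_attn_likely_supported() else "skip flash-attn, use torch_attn")
```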

102 changes: 74 additions & 28 deletions fam/llm/layers/attn.py
@@ -1,10 +1,18 @@
import math
import warnings

import torch
import torch.nn as nn
from flash_attn import ( # type: ignore
flash_attn_func,
flash_attn_qkvpacked_func,
flash_attn_with_kvcache,
)

try:
from flash_attn import ( # type: ignore
flash_attn_func,
flash_attn_qkvpacked_func,
flash_attn_with_kvcache,
)
except ImportError:
warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn")
from torch.nn import functional as F


class SelfAttention(nn.Module):
@@ -72,9 +80,7 @@ def _initialize_parameters(self, config):
self.dropout = config.dropout
self.causal = config.causal
self.attn_kernel_type = config.attn_kernel_type

if self.attn_kernel_type == "hand":
self.attn_dropout = nn.Dropout(config.dropout)
self.attn_dropout = nn.Dropout(config.dropout)

self.kv_cache_enabled = False

@@ -123,6 +129,40 @@ def _update_kv_cache(self, q, k, v):

return k, v

def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs Flash Attention 2.0 CUDA kernel based attention.
Args:
c_x: The input tensor.
Returns:
The output tensor.
"""
if self.kv_cache_enabled:
q, k, v = c_x.split(1, dim=2)
q = q.squeeze(2)
k = k.squeeze(2)
v = v.squeeze(2)

k, v = self._update_kv_cache(q, k, v)

y = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout if self.training else 0,
softmax_scale=None,
causal=self.causal,
)
else:
# efficient attention using Flash Attention 2.0 CUDA kernels
y = flash_attn_qkvpacked_func(
c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal
) # outputs (B, T, nh, hs)

return y

def _fd_attention(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs Flash decoding based attention.
@@ -189,40 +229,44 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=self.causal and (self.kv_cache_enabled is not True),
is_causal=self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0),
).transpose(
1, 2
) # (B, nh, T, hs) -> (B, T, nh, hs)

return y

def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs Flash Attention 2.0 CUDA kernel based attention.
Performs vanilla attention.
Args:
c_x: The input tensor.
Returns:
The output tensor.
"""
q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, nh, hs)
q = q.squeeze(2) # (B, T, nh, hs)
k = k.squeeze(2) # (B, T, nh, hs)
v = v.squeeze(2) # (B, T, nh, hs)

if self.kv_cache_enabled:
q, k, v = c_x.split(1, dim=2)
q = q.squeeze(2)
k = k.squeeze(2)
v = v.squeeze(2)
k, v = self._update_kv_cache(q, k, v)
y = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout if self.training else 0,
softmax_scale=None,
causal=self.causal,
)
else:
# efficient attention using Flash Attention 2.0 CUDA kernels
y = flash_attn_qkvpacked_func(
c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal
) # outputs (B, T, nh, hs)

q = q.transpose(1, 2) # (B, nh, T, hs)
k = k.transpose(1, 2) # (B, nh, T, hs)
v = v.transpose(1, 2) # (B, nh, T, hs)
att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T)
if self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0):
att = att.masked_fill(
torch.triu(torch.ones_like(att, dtype=torch.bool), diagonal=1), float("-inf")
) # (B, nh, T, T)
att = F.softmax(att, dim=-1) # (B, nh, T, T)
att = self.attn_dropout(att) # (B, nh, T, T)
y = att @ v # (B, nh, T, hs)
y = y.transpose(1, 2) # (B, T, nh, hs)

return y

def forward(self, x):
@@ -247,6 +291,8 @@ def forward(self, x):
y = self._fd_attention(c_x)
elif self.attn_kernel_type == "torch_attn":
y = self._torch_attn(c_x)
elif self.attn_kernel_type == "hand":
y = self._vanilla_attn(c_x)
else:
raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")

85 changes: 62 additions & 23 deletions fam/llm/mixins/causal.py
@@ -59,7 +59,7 @@ def _sample_next_token(
temperature: float,
top_k: Optional[int],
top_p: Optional[float],
guidance_scale: Optional[float],
guidance_scale: Optional[Tuple[float, float]],
) -> torch.Tensor:
"""
Predict the next token in the sequence.
@@ -87,14 +87,22 @@
) # list with len num_hierarchies of (b,1,vocab_size) tensors

if guidance_scale is not None:
assert idx_cond.shape[0] % 2 == 0
assert list_logits[0].shape[0] % 2 == 0
spkemb_guidance_scale, prompt_guidance_scale = guidance_scale
assert spkemb_guidance_scale >= 1
assert prompt_guidance_scale >= 1
base_scale = spkemb_guidance_scale + prompt_guidance_scale - 1

for i, logits in enumerate(list_logits):
logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0)
list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond

assert list_logits[0].shape[0] == idx_cond.shape[0] // 2
if prompt_guidance_scale > 1:
logits_cond, logits_uncond_spkemb, logits_uncond_prompt = logits.split(logits.shape[0] // 3, dim=0)
else:
logits_cond, logits_uncond_spkemb = logits.split(logits.shape[0] // 2, dim=0)
logits_uncond_prompt = 0
list_logits[i] = (
(base_scale) * logits_cond
+ (1 - spkemb_guidance_scale) * logits_uncond_spkemb
+ (1 - prompt_guidance_scale) * logits_uncond_prompt
)

# pluck the logits at the final step and scale by desired temperature
list_logits = [
@@ -178,7 +186,9 @@ def _sample_batch(
top_k: Optional[int],
top_p: Optional[float],
speaker_embs: Optional[torch.Tensor],
guidance_scale: Optional[float],
guidance_scale: Optional[Tuple[float, float]],
end_of_audio_token: int,
end_of_text_token: int,
):
"""
Samples a batch of tokens from the model.
@@ -202,33 +212,54 @@

min_seq_lens = min(seq_lens)
idx = idx[:, :, :min_seq_lens]
idx_out = torch.full(
(idx.shape[0], idx.shape[1], idx.shape[2] + max_new_tokens),
end_of_audio_token,
dtype=idx.dtype,
device=idx.device,
)
idx_out[:, :, :min_seq_lens] = idx
terminated = idx.new_zeros(idx.shape[0], dtype=torch.bool)

if guidance_scale is not None:
_, prompt_guidance_scale = guidance_scale
if speaker_embs is None:
raise Exception("Guidance is only supported for conditional models")

# create speaker embeddings equivalent to the batch size, filling with None
# for second half to do unconditional generation.
speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0])
speaker_embs = (
list(speaker_embs)
+ [None] * (speaker_embs.shape[0])
+ (list(speaker_embs) if prompt_guidance_scale > 1 else [])
)

for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
if terminated.all():
break
if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
idx_input = idx[:, :, -1:]
idx_input = idx_out[:, :, [timestep - 1]]
else:
idx_input = idx
idx_input = idx_out[:, :, :timestep]

if guidance_scale is not None:
_, prompt_guidance_scale = guidance_scale
# TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
if timestep == min_seq_lens:
print("[hack!!!!] Guidance is on, so we're doubling batch size!")
print("[hack!!!!] Guidance is on, so we're doubling/tripling batch size!")

# replicate idx in the batch dimension
idx_input = (
idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2])
idx_input.unsqueeze(0)
.repeat(3 if prompt_guidance_scale > 1 else 2, 1, 1, 1)
.reshape(-1, idx_input.shape[1], idx_input.shape[2])
)

# sanity checks
assert idx_input.shape[0] % 2 == 0
if prompt_guidance_scale > 1:
idx_input_uncond = idx_input[idx_input.shape[0] // 3 * 2 :]
idx_input_uncond = idx_input_uncond.view(-1)
# Replace all text tokens with endoftext token
idx_input_uncond[idx_input_uncond > end_of_audio_token] = end_of_text_token

idx_next = self._sample_next_token(
idx=idx_input,
@@ -247,12 +278,13 @@ def _sample_batch(
orig_input_at_t=input[:, :, timestep],
token_pred_mask_at_t=token_pred_mask[:, [timestep]],
)

idx_next = idx_next.unsqueeze(-1) # (b, num_hierarchies, T=1) tensor
is_endofaudio = (idx_next == end_of_audio_token).any(dim=-1) # shape: b
terminated = terminated | is_endofaudio
idx_next[terminated] = end_of_audio_token
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=2)
idx_out[:, :, timestep] = idx_next

return idx
return idx_out

@torch.no_grad()
def _sort_for_batching(
@@ -317,7 +349,10 @@ def _causal_sample(
top_p: Optional[float],
speaker_embs: Optional[torch.Tensor],
batch_size: int,
guidance_scale: Optional[float] = None,
guidance_scale: Optional[Tuple[float, float]] = None,
dtype: torch.dtype = torch.bfloat16,
end_of_audio_token: int,
end_of_text_token: int,
) -> torch.Tensor:
"""
Generates a sequence of tokens using causal sampling.
@@ -354,14 +389,16 @@

kv_batch_size = end_index - start_index
if guidance_scale is not None:
kv_batch_size = 2 * kv_batch_size
if guidance_scale[1] > 1:
kv_batch_size = 3 * kv_batch_size
else:
kv_batch_size = 2 * kv_batch_size

if self.kv_cache_enabled:
print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
self.empty_kv_cache(
batch_size=kv_batch_size,
kv_cache_maxlen=self.config.block_size,
dtype=torch.bfloat16,
dtype=dtype,
)

batch_seq_lens = seq_lens[start_index:end_index]
@@ -379,6 +416,8 @@ def _causal_sample(
top_p=top_p,
speaker_embs=batch_speaker_embs,
guidance_scale=guidance_scale,
end_of_audio_token=end_of_audio_token,
end_of_text_token=end_of_text_token,
)
return_idx[start_index:end_index] = batch_idx

11 changes: 8 additions & 3 deletions fam/llm/model.py
@@ -1,7 +1,7 @@
import inspect
import math
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from typing import Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -350,7 +350,10 @@ def generate(
top_p: Optional[float] = None,
speaker_embs: Optional[torch.Tensor] = None,
batch_size: Optional[int] = None,
guidance_scale: Optional[float] = None,
guidance_scale: Optional[Tuple[float, float]] = None,
dtype: torch.dtype = torch.bfloat16,
end_of_audio_token: int = 99999, # Dummy values will disable early termination / guidance features.
end_of_text_token: int = 99999,
):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
@@ -373,6 +376,9 @@ def generate(
speaker_embs=speaker_embs,
batch_size=batch_size,
guidance_scale=guidance_scale,
dtype=dtype,
end_of_audio_token=end_of_audio_token,
end_of_text_token=end_of_text_token,
)

else:
@@ -400,4 +406,3 @@ def generate(
)
)
return torch.cat(out, dim=0)
return torch.cat(out, dim=0)
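The `end_of_audio_token=99999` / `end_of_text_token=99999` defaults above are dummy sentinels that keep existing callers working: as the inline comment notes, no real token id matches them, so early termination and the new guidance masking stay inert. When a real `end_of_audio_token` is passed, `_sample_batch` pre-fills the output buffer with it and freezes finished rows, which the toy loop below sketches (stand-in random sampler, not the repo's model).

```python
import torch

# Minimal sketch of the early-termination bookkeeping introduced in
# _sample_batch: the output buffer is pre-filled with end_of_audio_token and a
# per-sequence `terminated` mask freezes finished rows, so sampling can stop
# once every sequence has emitted end-of-audio.
end_of_audio_token = 7
b, num_hierarchies, prompt_len, max_new_tokens = 3, 2, 4, 16

idx = torch.randint(0, end_of_audio_token, (b, num_hierarchies, prompt_len))
idx_out = torch.full((b, num_hierarchies, prompt_len + max_new_tokens), end_of_audio_token, dtype=idx.dtype)
idx_out[:, :, :prompt_len] = idx
terminated = torch.zeros(b, dtype=torch.bool)

for t in range(prompt_len, prompt_len + max_new_tokens):
    if terminated.all():
        break
    # Stand-in for _sample_next_token: random tokens, occasionally end-of-audio.
    idx_next = torch.randint(0, end_of_audio_token + 1, (b, num_hierarchies))
    terminated |= (idx_next == end_of_audio_token).any(dim=-1)
    idx_next[terminated] = end_of_audio_token
    idx_out[:, :, t] = idx_next

# Every terminated sequence ends in end_of_audio_token padding.
assert (idx_out[terminated][:, :, -1] == end_of_audio_token).all()
```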