From 43f97a0c93b2d3ebe1eaeba242d56a2c46f7e982 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Sok=C3=B3lski?=
Date: Mon, 12 Feb 2024 18:33:32 +0000
Subject: [PATCH] fix: various fixes and enhancements (#46)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix: various fixes and enhancements

* Update sample.py

Signed-off-by: Piotr Sokólski

* Update serving.py

Signed-off-by: Piotr Sokólski

* move requirements

---------

Signed-off-by: Piotr Sokólski
---
 README.md                |   4 ++
 fam/llm/layers/attn.py   | 102 ++++++++++++++++++++++++++++-----------
 fam/llm/mixins/causal.py |  85 +++++++++++++++++++++++---------
 fam/llm/model.py         |  11 +++--
 fam/llm/sample.py        |  26 +++++-----
 fam/llm/serving.py       |  13 +++--
 requirements.txt         |   1 -
 7 files changed, 168 insertions(+), 74 deletions(-)

diff --git a/README.md b/README.md
index 8d8f285..440d883 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,10 @@ sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
 rm -rf ffmpeg-git-*
 
 pip install -r requirements.txt
+
+# Works only on the latest NVIDIA GPUs. If you have a different GPU, do not install this.
+pip install flash-attn
+
 pip install -e .
 ```
 
diff --git a/fam/llm/layers/attn.py b/fam/llm/layers/attn.py
index 47358c7..a90211e 100644
--- a/fam/llm/layers/attn.py
+++ b/fam/llm/layers/attn.py
@@ -1,10 +1,18 @@
+import math
+import warnings
+
 import torch
 import torch.nn as nn
-from flash_attn import (  # type: ignore
-    flash_attn_func,
-    flash_attn_qkvpacked_func,
-    flash_attn_with_kvcache,
-)
+
+try:
+    from flash_attn import (  # type: ignore
+        flash_attn_func,
+        flash_attn_qkvpacked_func,
+        flash_attn_with_kvcache,
+    )
+except ImportError:
+    warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn")
+from torch.nn import functional as F
 
 
 class SelfAttention(nn.Module):
@@ -72,9 +80,7 @@ def _initialize_parameters(self, config):
         self.dropout = config.dropout
         self.causal = config.causal
         self.attn_kernel_type = config.attn_kernel_type
-
-        if self.attn_kernel_type == "hand":
-            self.attn_dropout = nn.Dropout(config.dropout)
+        self.attn_dropout = nn.Dropout(config.dropout)
 
         self.kv_cache_enabled = False
 
@@ -123,6 +129,40 @@ def _update_kv_cache(self, q, k, v):
 
         return k, v
 
+    def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
+        """
+        Performs Flash Attention 2.0 CUDA kernel based attention.
+
+        Args:
+            c_x: The input tensor.
+
+        Returns:
+            The output tensor.
+        """
+        if self.kv_cache_enabled:
+            q, k, v = c_x.split(1, dim=2)
+            q = q.squeeze(2)
+            k = k.squeeze(2)
+            v = v.squeeze(2)
+
+            k, v = self._update_kv_cache(q, k, v)
+
+            y = flash_attn_func(
+                q,
+                k,
+                v,
+                dropout_p=self.dropout if self.training else 0,
+                softmax_scale=None,
+                causal=self.causal,
+            )
+        else:
+            # efficient attention using Flash Attention 2.0 CUDA kernels
+            y = flash_attn_qkvpacked_func(
+                c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal
+            )  # outputs (B, T, nh, hs)
+
+        return y
+
     def _fd_attention(self, c_x: torch.Tensor) -> torch.Tensor:
         """
         Performs Flash decoding based attention.
@@ -189,40 +229,44 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
                 v,
                 attn_mask=None,
                 dropout_p=self.dropout if self.training else 0,
-                is_causal=self.causal and (self.kv_cache_enabled is not True),
+                is_causal=self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0),
             ).transpose(
                 1, 2
             )  # (B, nh, T, hs) -> (B, T, nh, hs)
 
         return y
 
-    def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
+    def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor:
         """
-        Performs Flash Attention 2.0 CUDA kernel based attention.
+        Performs vanilla attention.
+
         Args:
             c_x: The input tensor.
+
         Returns:
             The output tensor.
         """
+        q, k, v = c_x.split(1, dim=2)  # q, k, v of shape (B, T, nh, hs)
+        q = q.squeeze(2)  # (B, T, nh, hs)
+        k = k.squeeze(2)  # (B, T, nh, hs)
+        v = v.squeeze(2)  # (B, T, nh, hs)
+
         if self.kv_cache_enabled:
-            q, k, v = c_x.split(1, dim=2)
-            q = q.squeeze(2)
-            k = k.squeeze(2)
-            v = v.squeeze(2)
             k, v = self._update_kv_cache(q, k, v)
-            y = flash_attn_func(
-                q,
-                k,
-                v,
-                dropout_p=self.dropout if self.training else 0,
-                softmax_scale=None,
-                causal=self.causal,
-            )
-        else:
-            # efficient attention using Flash Attention 2.0 CUDA kernels
-            y = flash_attn_qkvpacked_func(
-                c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal
-            )  # outputs (B, T, nh, hs)
+
+        q = q.transpose(1, 2)  # (B, nh, T, hs)
+        k = k.transpose(1, 2)  # (B, nh, T, hs)
+        v = v.transpose(1, 2)  # (B, nh, T, hs)
+        att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
+        if self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0):
+            att = att.masked_fill(
+                torch.triu(torch.ones_like(att, dtype=torch.bool), diagonal=1), float("-inf")
+            )  # (B, nh, T, T)
+        att = F.softmax(att, dim=-1)  # (B, nh, T, T)
+        att = self.attn_dropout(att)  # (B, nh, T, T)
+        y = att @ v  # (B, nh, T, hs)
+        y = y.transpose(1, 2)  # (B, T, nh, hs)
+
         return y
 
     def forward(self, x):
@@ -247,6 +291,8 @@ def forward(self, x):
             y = self._fd_attention(c_x)
         elif self.attn_kernel_type == "torch_attn":
             y = self._torch_attn(c_x)
+        elif self.attn_kernel_type == "hand":
+            y = self._vanilla_attn(c_x)
         else:
             raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
 
diff --git a/fam/llm/mixins/causal.py b/fam/llm/mixins/causal.py
index 4dd77ed..62ae9d5 100644
--- a/fam/llm/mixins/causal.py
+++ b/fam/llm/mixins/causal.py
@@ -59,7 +59,7 @@ def _sample_next_token(
         temperature: float,
         top_k: Optional[int],
         top_p: Optional[float],
-        guidance_scale: Optional[float],
+        guidance_scale: Optional[Tuple[float, float]],
     ) -> torch.Tensor:
         """
         Predict the next token in the sequence.
@@ -87,14 +87,22 @@ def _sample_next_token(
         )  # list with len num_hierarchies of (b,1,vocab_size) tensors
 
         if guidance_scale is not None:
-            assert idx_cond.shape[0] % 2 == 0
-            assert list_logits[0].shape[0] % 2 == 0
+            spkemb_guidance_scale, prompt_guidance_scale = guidance_scale
+            assert spkemb_guidance_scale >= 1
+            assert prompt_guidance_scale >= 1
+            base_scale = spkemb_guidance_scale + prompt_guidance_scale - 1
 
             for i, logits in enumerate(list_logits):
-                logits_cond, logits_uncond = logits.split(logits.shape[0] // 2, dim=0)
-                list_logits[i] = (guidance_scale) * logits_cond + (1 - guidance_scale) * logits_uncond
-
-            assert list_logits[0].shape[0] == idx_cond.shape[0] // 2
+                if prompt_guidance_scale > 1:
+                    logits_cond, logits_uncond_spkemb, logits_uncond_prompt = logits.split(logits.shape[0] // 3, dim=0)
+                else:
+                    logits_cond, logits_uncond_spkemb = logits.split(logits.shape[0] // 2, dim=0)
+                    logits_uncond_prompt = 0
+                list_logits[i] = (
+                    (base_scale) * logits_cond
+                    + (1 - spkemb_guidance_scale) * logits_uncond_spkemb
+                    + (1 - prompt_guidance_scale) * logits_uncond_prompt
+                )
 
         # pluck the logits at the final step and scale by desired temperature
         list_logits = [
@@ -178,7 +186,9 @@ def _sample_batch(
         top_k: Optional[int],
         top_p: Optional[float],
         speaker_embs: Optional[torch.Tensor],
-        guidance_scale: Optional[float],
+        guidance_scale: Optional[Tuple[float, float]],
+        end_of_audio_token: int,
+        end_of_text_token: int,
     ):
         """
         Samples a batch of tokens from the model.
@@ -202,33 +212,54 @@ def _sample_batch(
 
         min_seq_lens = min(seq_lens)
         idx = idx[:, :, :min_seq_lens]
 
+        idx_out = torch.full(
+            (idx.shape[0], idx.shape[1], idx.shape[2] + max_new_tokens),
+            end_of_audio_token,
+            dtype=idx.dtype,
+            device=idx.device,
+        )
+        idx_out[:, :, :min_seq_lens] = idx
+        terminated = idx.new_zeros(idx.shape[0], dtype=torch.bool)
         if guidance_scale is not None:
+            _, prompt_guidance_scale = guidance_scale
             if speaker_embs is None:
                 raise Exception("Guidance is only supported for conditional models")
 
             # create speaker embeddings equivalent to the batch size, filling with None
             # for second half to do unconditional generation.
-            speaker_embs = list(speaker_embs) + [None] * (speaker_embs.shape[0])
+            speaker_embs = (
+                list(speaker_embs)
+                + [None] * (speaker_embs.shape[0])
+                + (list(speaker_embs) if prompt_guidance_scale > 1 else [])
+            )
 
         for timestep in tqdm.tqdm(range(min_seq_lens, min_seq_lens + max_new_tokens), desc="tokens: "):
+            if terminated.all():
+                break
             if (self.kv_cache_enabled is True) and (timestep > min_seq_lens):
-                idx_input = idx[:, :, -1:]
+                idx_input = idx_out[:, :, [timestep - 1]]
             else:
-                idx_input = idx
+                idx_input = idx_out[:, :, :timestep]
 
             if guidance_scale is not None:
+                _, prompt_guidance_scale = guidance_scale
                 # TODO: fix: will cause a problem with kv-caching as it's not expecting larger batch-size.
                 if timestep == min_seq_lens:
-                    print("[hack!!!!] Guidance is on, so we're doubling batch size!")
+                    print("[hack!!!!] Guidance is on, so we're doubling/tripling batch size!")
 
                 # replicate idx in the batch dimension
                 idx_input = (
-                    idx_input.unsqueeze(0).repeat(2, 1, 1, 1).reshape(-1, idx_input.shape[1], idx_input.shape[2])
+                    idx_input.unsqueeze(0)
+                    .repeat(3 if prompt_guidance_scale > 1 else 2, 1, 1, 1)
+                    .reshape(-1, idx_input.shape[1], idx_input.shape[2])
                 )
 
-                # sanity checks
-                assert idx_input.shape[0] % 2 == 0
+                if prompt_guidance_scale > 1:
+                    idx_input_uncond = idx_input[idx_input.shape[0] // 3 * 2 :]
+                    idx_input_uncond = idx_input_uncond.view(-1)
+                    # Replace all text tokens with endoftext token
+                    idx_input_uncond[idx_input_uncond > end_of_audio_token] = end_of_text_token
 
             idx_next = self._sample_next_token(
                 idx=idx_input,
@@ -247,12 +278,13 @@ def _sample_batch(
                 orig_input_at_t=input[:, :, timestep],
                 token_pred_mask_at_t=token_pred_mask[:, [timestep]],
             )
-
-            idx_next = idx_next.unsqueeze(-1)  # (b, num_hierarchies, T=1) tensor
+            is_endofaudio = (idx_next == end_of_audio_token).any(dim=-1)  # shape: b
+            terminated = terminated | is_endofaudio
+            idx_next[terminated] = end_of_audio_token
             # append sampled index to the running sequence and continue
-            idx = torch.cat((idx, idx_next), dim=2)
+            idx_out[:, :, timestep] = idx_next
 
-        return idx
+        return idx_out
 
     @torch.no_grad()
     def _sort_for_batching(
@@ -317,7 +349,10 @@ def _causal_sample(
         top_p: Optional[float],
         speaker_embs: Optional[torch.Tensor],
         batch_size: int,
-        guidance_scale: Optional[float] = None,
+        guidance_scale: Optional[Tuple[float, float]] = None,
+        dtype: torch.dtype = torch.bfloat16,
+        end_of_audio_token: int,
+        end_of_text_token: int,
     ) -> torch.Tensor:
         """
         Generates a sequence of tokens using causal sampling.
@@ -354,14 +389,16 @@ def _causal_sample(
 
             kv_batch_size = end_index - start_index
             if guidance_scale is not None:
-                kv_batch_size = 2 * kv_batch_size
+                if guidance_scale[1] > 1:
+                    kv_batch_size = 3 * kv_batch_size
+                else:
+                    kv_batch_size = 2 * kv_batch_size
 
             if self.kv_cache_enabled:
-                print("!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16")
                 self.empty_kv_cache(
                     batch_size=kv_batch_size,
                     kv_cache_maxlen=self.config.block_size,
-                    dtype=torch.bfloat16,
+                    dtype=dtype,
                 )
 
             batch_seq_lens = seq_lens[start_index:end_index]
@@ -379,6 +416,8 @@ def _causal_sample(
                 top_p=top_p,
                 speaker_embs=batch_speaker_embs,
                 guidance_scale=guidance_scale,
+                end_of_audio_token=end_of_audio_token,
+                end_of_text_token=end_of_text_token,
             )
 
             return_idx[start_index:end_index] = batch_idx
diff --git a/fam/llm/model.py b/fam/llm/model.py
index 90d69e4..c6e3c2f 100644
--- a/fam/llm/model.py
+++ b/fam/llm/model.py
@@ -1,7 +1,7 @@
 import inspect
 import math
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Union
+from typing import Literal, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -350,7 +350,10 @@ def generate(
         top_p: Optional[float] = None,
         speaker_embs: Optional[torch.Tensor] = None,
         batch_size: Optional[int] = None,
-        guidance_scale: Optional[float] = None,
+        guidance_scale: Optional[Tuple[float, float]] = None,
+        dtype: torch.dtype = torch.bfloat16,
+        end_of_audio_token: int = 99999,  # Dummy values will disable early termination / guidance features.
+        end_of_text_token: int = 99999,
     ):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,num_hierarchies,t)) and complete
@@ -373,6 +376,9 @@ def generate(
                 speaker_embs=speaker_embs,
                 batch_size=batch_size,
                 guidance_scale=guidance_scale,
+                dtype=dtype,
+                end_of_audio_token=end_of_audio_token,
+                end_of_text_token=end_of_text_token,
             )
 
         else:
@@ -400,4 +406,3 @@ def generate(
                     )
                 )
             return torch.cat(out, dim=0)
-        return torch.cat(out, dim=0)
diff --git a/fam/llm/sample.py b/fam/llm/sample.py
index b74774b..b36dd80 100644
--- a/fam/llm/sample.py
+++ b/fam/llm/sample.py
@@ -8,7 +8,7 @@
 import tempfile
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Literal, Optional, Type, Union
+from typing import List, Literal, Optional, Tuple, Type, Union
 
 import librosa
 import torch
@@ -47,10 +47,6 @@ def __str__(self):
 
 
 class Model:
-    """
-    Class to sample from a trained model.
-    """
-
     def __init__(
         self,
         config: InferenceConfig,
@@ -71,14 +67,14 @@ def __init__(
         torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False  # allow tf32 on matmul
         torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False  # allow tf32 on cudnn
         device_type = "cuda" if "cuda" in config.device else "cpu"  # for later use in torch.autocast
-        ptdtype = {
+        self.ptdtype = {
             "float32": torch.float32,
             "tfloat32": torch.float32,
             "bfloat16": torch.bfloat16,
             "float16": torch.float16,
         }[config.dtype]
         self._ctx = (
-            nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+            nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)
         )
 
         self.use_bpe_tokenizer = False
@@ -156,10 +152,7 @@ def _init_model(self):
                     block.attn.attn_kernel_type = "fd"
             elif self.use_kv_cache == "vanilla":
                 for block in self.model.transformer.h:
-                    if block.attn.attn_kernel_type != "fa2":
-                        raise Exception(
-                            f"kv_cache only supported for flash attention 2 but found {block.attn.attn_kernel_type} inside model!"
-                        )
+                    block.attn.attn_kernel_type = "torch_attn"
                 self.model.enable_kv_cache()
             else:
                 raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
@@ -245,6 +238,9 @@ def causal_sample(
                 speaker_embs=speaker_embs,
                 batch_size=batch_size,
                 guidance_scale=guidance_scale,
+                dtype=self.ptdtype,
+                end_of_audio_token=self.tokenizer.offset - 1,
+                end_of_text_token=self.tokenizer.eot_token,
             )
             for i in range(len(y)):
                 to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))
@@ -455,7 +451,7 @@ def _sample_utterance_batch(
     enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
     first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
-    guidance_scale: Optional[float],
+    guidance_scale: Optional[Tuple[float, float]],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
@@ -533,7 +529,7 @@ def sample_utterance(
    enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
    first_stage_ckpt_path: str,
    second_stage_ckpt_path: str,
-    guidance_scale: Optional[float],
+    guidance_scale: Optional[Tuple[float, float]],
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: Optional[float],
@@ -649,8 +645,8 @@ class SamplingControllerConfig:
     output_dir: str = "samples/"
     """Relative path to output directory"""
 
-    guidance_scale: Optional[float] = 3.0
-    """Guidance scale for sampling."""
+    guidance_scale: Optional[Tuple[float, float]] = (3.0, 1.0)
+    """Guidance scale for sampling: (speaker conditioning guidance_scale, prompt conditioning guidance scale)."""
 
     batch_size: int = 128
     """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For H100, and 1B model,
diff --git a/fam/llm/serving.py b/fam/llm/serving.py
index a6c8fb5..dc3f856 100644
--- a/fam/llm/serving.py
+++ b/fam/llm/serving.py
@@ -3,8 +3,9 @@
 import shlex
 import subprocess
 import tempfile
+import warnings
 from pathlib import Path
-from typing import Literal, Optional
+from typing import Literal, Optional, Tuple
 
 import fastapi
 import fastapi.middleware.cors
@@ -73,7 +74,7 @@ class _GlobalState:
 @dataclass(frozen=True)
 class TTSRequest:
     text: str
-    guidance: Optional[float] = 3.0
+    guidance: Optional[Tuple[float, float]] = (3.0, 1.0)
     top_p: Optional[float] = 0.95
     speaker_ref_path: Optional[str] = None
     top_k: Optional[int] = None
@@ -95,6 +96,9 @@ async def text_to_speech(req: Request):
             wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
         else:
             wav_path = tts_req.speaker_ref_path
+        if wav_path is None:
+            warnings.warn("Running without speaker reference")
+            assert tts_req.guidance is None
         wav_out_path = sample_utterance(
             tts_req.text,
             wav_path,
@@ -126,7 +130,8 @@ async def text_to_speech(req: Request):
 
 def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
     with tempfile.NamedTemporaryFile() as unknown_format_tmp:
-        assert unknown_format_tmp.write(audiodata) > 0
+        if unknown_format_tmp.write(audiodata) == 0:
+            return None
         unknown_format_tmp.flush()
 
         subprocess.check_output(
@@ -161,7 +166,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
         seed=1337,
         device=device,
         dtype=GlobalState.config.dtype,
-        compile=False,
+        compile=GlobalState.config.compile,
         init_from="resume",
         output_dir=tempfile.mkdtemp(),
     )
diff --git a/requirements.txt b/requirements.txt
index e03bcdd..dfc3ce4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,6 @@ tiktoken==0.5.1
 audiocraft
 numpy
 ninja
-flash-attn
 fastapi
 uvicorn
 tyro