From 6f89920c35b2a48b68267265c40ea7bb0c65c7e9 Mon Sep 17 00:00:00 2001
From: Optimox
Date: Tue, 22 Oct 2024 11:05:47 +0200
Subject: [PATCH] WIP: non working flex attention

---
 recipes/configs/gemma2/2B_full.yaml           |   6 +-
 recipes/configs/gemma2/2B_lora.yaml           |   6 +-
 .../configs/gemma2/2B_lora_single_device.yaml |   8 +-
 .../gemma2/2B_qlora_single_device.yaml        |   6 +-
 recipes/lora_finetune_single_device.py        |   1 -
 torchtune/models/gemma2/_attention.py         | 309 +++++++++++++++++-
 torchtune/models/gemma2/_attention_utils.py   |  96 ++++++
 .../models/gemma2/_component_builders.py      |  26 +-
 8 files changed, 438 insertions(+), 20 deletions(-)
 create mode 100644 torchtune/models/gemma2/_attention_utils.py

diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml
index f1214810a9..9386fae4b9 100644
--- a/recipes/configs/gemma2/2B_full.yaml
+++ b/recipes/configs/gemma2/2B_full.yaml
@@ -19,7 +19,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.gemma.gemma_tokenizer
-  path: /tmp/gemma2-2b/tokenizer.model
+  path: /tmp/gemma-2-2b/tokenizer.model

 # Dataset
 dataset:
@@ -33,14 +33,14 @@ model:

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/gemma2-2b/
+  checkpoint_dir: /tmp/gemma-2-2b/
   checkpoint_files: [
     model-00001-of-00003.safetensors,
     model-00002-of-00003.safetensors,
     model-00003-of-00003.safetensors,
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/gemma2-2b
+  output_dir: /tmp/gemma-2-2b
   model_type: GEMMA2

 resume_from_checkpoint: False
diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml
index ca6d8df232..e6ef6e6e9e 100644
--- a/recipes/configs/gemma2/2B_lora.yaml
+++ b/recipes/configs/gemma2/2B_lora.yaml
@@ -18,7 +18,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.gemma.gemma_tokenizer
-  path: /tmp/gemma2-2b/tokenizer.model
+  path: /tmp/gemma-2-2b/tokenizer.model

 # Dataset
 dataset:
@@ -37,14 +37,14 @@ model:

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/gemma2-2b/
+  checkpoint_dir: /tmp/gemma-2-2b/
   checkpoint_files: [
     model-00001-of-00003.safetensors,
     model-00002-of-00003.safetensors,
     model-00003-of-00003.safetensors,
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/gemma2-2b
+  output_dir: /tmp/gemma-2-2b
   model_type: GEMMA2

 resume_from_checkpoint: False
diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml
index d8bbeb9a81..49b59846c4 100644
--- a/recipes/configs/gemma2/2B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/2B_lora_single_device.yaml
@@ -18,7 +18,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.gemma.gemma_tokenizer
-  path: /tmp/gemma2-2b/tokenizer.model
+  path: /tmp/gemma-2-2b/tokenizer.model

 # Dataset
 dataset:
@@ -44,7 +44,7 @@ checkpointer:
     model-00003-of-00003.safetensors,
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/gemma2-2b
+  output_dir: /tmp/gemma-2-2b
   model_type: GEMMA2
 resume_from_checkpoint: False
 save_adapter_weights_only: False
@@ -62,10 +62,10 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

 # Fine-tuning arguments
-batch_size: 4
+batch_size: 8
 epochs: 3
 max_steps_per_epoch: null
-gradient_accumulation_steps: 4
+gradient_accumulation_steps: 2
 compile: False

 # Training env
diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml
index c65367419f..b5d7c9147d 100644
--- a/recipes/configs/gemma2/2B_qlora_single_device.yaml
+++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml
@@ -18,7 +18,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.gemma.gemma_tokenizer
-  path: /tmp/gemma2-2b/tokenizer.model
+  path: /tmp/gemma-2-2b/tokenizer.model

 # Dataset
 dataset:
@@ -37,14 +37,14 @@ model:

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/gemma2-2b/
+  checkpoint_dir: /tmp/gemma-2-2b/
   checkpoint_files: [
     model-00001-of-00003.safetensors,
     model-00002-of-00003.safetensors,
     model-00003-of-00003.safetensors,
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/gemma2-2b
+  output_dir: /tmp/gemma-2-2b
   model_type: GEMMA2
 resume_from_checkpoint: False
 save_adapter_weights_only: False
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index 5d39b72086..4f567e2c9a 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -616,7 +616,6 @@ def save_checkpoint(self, epoch: int) -> None:
     def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")

-        # run model
         with self.activations_handling_ctx:
             logits = self._model(**batch)
diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py
index c83212f7b5..e4d0949d0f 100644
--- a/torchtune/models/gemma2/_attention.py
+++ b/torchtune/models/gemma2/_attention.py
@@ -12,8 +12,15 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType
 from torchtune.modules.kv_cache import KVCache
-
-
+from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+
+if _SUPPORTS_FLEX_ATTENTION:
+    from torch.nn.attention.flex_attention import create_block_mask
+    from torchtune.models.gemma2._attention_utils import (
+        compile_friendly_flex_attention_with_score_and_block,
+        flex_causal_sliding_window,
+        flex_tanh_soft_capping_with_scaling,
+    )

 logger = logging.getLogger(__name__)
@@ -282,8 +289,18 @@ def forward(
         q.mul_(self.scaling)
         output = torch.matmul(q, k.transpose(2, 3))

+        # default to a causal mask when no mask is provided
+        if mask is None:
+            mask = torch.tril(
+                torch.ones(
+                    size=(s_x, s_x),
+                    dtype=torch.bool,
+                ).to(x.device)
+            )
+
         if self.sliding_window_size is not None:
             all_ones = torch.ones_like(mask)
+
             sliding_mask = torch.triu(
                 all_ones, -1 * self.sliding_window_size + 1
             ) * torch.tril(all_ones, self.sliding_window_size - 1)
@@ -303,3 +320,291 @@ def forward(
         # reshape the output to be the same shape as the input
         output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
         return self.output_proj(output)
+
+
+class FlexGemma2Attention(nn.Module):
+    """
+    Adapted from the official Google PyTorch implementation
+    https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213
+    to match torchtune style.
+    A new attention module had to be added since nn.functional.scaled_dot_product_attention does not allow soft capping.
+    Args:
+        embed_dim (int): embedding dimension for the model
+        num_heads (int): number of query heads. For MHA this is also the
+            number of heads for key and value
+        num_kv_heads (int): number of key and value heads. User should ensure
+            ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
+            for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
+        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
+        q_proj (nn.Module): projection layer for query.
+        k_proj (nn.Module): projection layer for key.
+        v_proj (nn.Module): projection layer for value.
+        output_proj (nn.Module): projection layer for output.
+        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
+        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
+            before updating from kv_cache. This means it will only support token-wide normalization and not
+            batch- or sequence-wide normalization.
+        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
+        kv_cache (Optional[KVCache]): KVCache object used to cache key and value
+        max_seq_len (int): maximum sequence length supported by the model.
+            This is needed to compute the RoPE Cache. Default: 4096.
+        is_causal (bool): sets the default mask to causal when no mask is provided
+        attn_dropout (float): dropout value passed onto the
+            scaled_dot_product_attention function. This argument is ignored if
+            self.training is False. Default value is 0.0.
+        sliding_window_size (Optional[int]): size of the sliding window; if None, no sliding window is applied
+        softcapping (Optional[float]): capping value used for soft capping; if None, no capping is performed
+        query_pre_attn_scalar (Optional[int]): value used for pre-attention normalization; if None, head_dim is used instead
+    Raises:
+        ValueError: If ``num_heads % num_kv_heads != 0``
+        ValueError: If ``embed_dim % num_heads != 0``
+        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
+        ValueError: If q_norm is defined without k_norm or vice versa
+    """
+
+    def __init__(
+        self,
+        *,
+        embed_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        q_proj: nn.Module,
+        k_proj: nn.Module,
+        v_proj: nn.Module,
+        output_proj: nn.Module,
+        pos_embeddings: Optional[nn.Module] = None,
+        q_norm: Optional[nn.Module] = None,
+        k_norm: Optional[nn.Module] = None,
+        kv_cache: Optional[KVCache] = None,
+        max_seq_len: int = 4096,
+        is_causal: bool = True,
+        attn_dropout: float = 0.0,
+        sliding_window_size: Optional[int] = None,
+        softcapping: Optional[float] = 50.0,
+        query_pre_attn_scalar: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if num_heads % num_kv_heads != 0:
+            raise ValueError(
+                f"num_heads ({num_heads}) must be divisible by "
+                f"num_kv_heads ({num_kv_heads})"
+            )
+
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) must be divisible by "
+                f"num_heads ({num_heads})"
+            )
+
+        if attn_dropout < 0 or attn_dropout > 1:
+            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")
+
+        if bool(q_norm) ^ bool(k_norm):
+            raise ValueError("q and k norm must be set together")
+
+        # Set attributes
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.embed_dim = embed_dim
+        self.attn_dropout = attn_dropout
+        self.head_dim = head_dim
+        self.max_seq_len = max_seq_len
+        self.is_causal = is_causal
+
+        # Set layers
+        self.kv_cache = kv_cache
+        self.q_proj = q_proj
+        self.k_proj = k_proj
+        self.v_proj = v_proj
+        self.output_proj = output_proj
+        self.q_norm = q_norm
+        self.k_norm = k_norm
+        self.pos_embeddings = pos_embeddings
+
+        # Gemma2-related parameters
+        self.sliding_window_size = sliding_window_size
+        self.softcapping = softcapping
+        if query_pre_attn_scalar is not None:
+            # flex attention always applies the head_dim**-0.5 normalization, so it has to be folded into the scaling
+            self.scaling = query_pre_attn_scalar**-0.5 / self.head_dim**-0.5
+        else:
+            self.scaling = None
+
+        self.mask_mod = flex_causal_sliding_window(self.sliding_window_size)
+        self.score_mod = flex_tanh_soft_capping_with_scaling(
+            self.softcapping, self.scaling
+        )
+
+    def setup_cache(
+        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
+    ) -> None:
+        """Setup key value caches for attention calculation. If called
+        after kv_cache is already setup, this will be skipped.
+
+        Args:
+            batch_size (int): batch size for the caches.
+            dtype (torch.dtype): dtype for the caches.
+            max_seq_len (int): maximum sequence length model will be run with.
+        """
+        # Don't overwrite user defined kv_cache from init
+        if self.kv_cache is not None:
+            logger.warning(
+                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
+            )
+        else:
+            self.kv_cache = KVCache(
+                batch_size=batch_size,
+                max_seq_len=max_seq_len,
+                num_heads=self.num_heads,
+                head_dim=self.head_dim,
+                dtype=dtype,
+            )
+
+    def reset_cache(self):
+        """Reset the key value caches."""
+        if self.kv_cache is None:
+            raise RuntimeError(
+                "Key value caches are not setup. Call ``setup_caches()`` first."
+            )
+        self.kv_cache.reset()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        *,
+        mask: Optional[_MaskType] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
+            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
+                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
+            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
+                and before the softmax. Either:
+
+                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
+                or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
+                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
+                token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
+                is used by default.
+
+                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
+                created via `create_block_mask `_. We use
+                :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
+                Default is None.
+            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
+                of each token. During training, this is used to indicate the positions
+                of each token relative to its sample when packed, shape [b x s].
+                During inference, this indicates the position of the current token.
+                If none, assume the index of the token is its position id. Default is None.
+
+        Raises:
+            ValueError: If no ``y`` input and ``kv_cache`` is not enabled.
+
+        Returns:
+            torch.Tensor: output tensor with attention applied
+
+        Notation used for tensor shapes:
+            - b: batch size
+            - s_x: sequence length for x
+            - s_y: sequence length for y
+            - n_h: num heads
+            - n_kv: num kv heads
+            - d: embed dim
+            - h_d: head dim
+        """
+        # x has shape [b, s_x, d]
+        # y has shape [b, s_y, d]
+        b, s_x, _ = x.shape
+        s_y = y.shape[1] if y is not None else 0
+
+        # q has shape [b, s_x, num_heads * head_dim]
+        q = self.q_proj(x)
+
+        # number of queries per key/value
+        q_per_kv = self.num_heads // self.num_kv_heads
+        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+
+        # Apply positional embeddings
+        if self.pos_embeddings is not None:
+            q = self.pos_embeddings(q, input_pos=input_pos)
+
+        # [b, n_h, s_x, h_d]
+        q = q.transpose(1, 2)
+
+        # Normalize q
+        if self.q_norm is not None:
+            q = self.q_norm(q)
+
+        if y is None:
+            if self.kv_cache is None:
+                raise ValueError(
+                    "Must provide y input or use kv_cache to enable streaming decoding"
+                )
+            k = self.kv_cache.k_cache
+            v = self.kv_cache.v_cache
+        else:
+            # Update k and v shape, positional embeddings, and normalization
+
+            # k has shape [b, s_y, num_kv_heads * head_dim]
+            # v has shape [b, s_y, num_kv_heads * head_dim]
+            k = self.k_proj(y)
+            v = self.v_proj(y)
+
+            # Apply positional embeddings
+            # k: [b, s_y, n_kv, h_d]
+            k = k.view(b, s_y, -1, self.head_dim)
+            if self.pos_embeddings is not None:
+                k = self.pos_embeddings(k, input_pos=input_pos)
+
+            # View + expand + reshape bring num_kv_heads to num_heads for k and v
+            # to match q.
+
+            # k: [b, s_y, n_kv, 1, h_d]
+            # v: [b, s_y, n_kv, 1, h_d]
+            k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
+            v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
+
+            # If needed, expand the key and value tensors to have the same shape
+            # as the query tensor by copying values across the relevant dim
+            if self.num_heads != self.num_kv_heads:
+                k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
+                v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
+
+            # [b, s, n_h, h_d]
+            k = k.reshape(b, s_y, -1, self.head_dim)
+            v = v.reshape(b, s_y, -1, self.head_dim)
+
+            # [b, n_h, s, h_d]
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            # Normalize k
+            if self.k_norm is not None:
+                k = self.k_norm(k)
+
+        # Update key-value cache
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(k, v)
+
+        # TODO: how can we avoid recomputing the same block mask at every layer?
+        # https://pytorch.org/blog/flexattention/#q-when-should-we-recompute-the-blockmask
+        block_mask = create_block_mask(
+            mask_mod=self.mask_mod,
+            B=b,
+            H=self.num_heads,
+            Q_LEN=s_x,
+            KV_LEN=s_x,
+            device=q.device,
+        )
+
+        output = compile_friendly_flex_attention_with_score_and_block(
+            q, k, v, score_mod=self.score_mod, block_mask=block_mask
+        )
+
+        # reshape the output to be the same shape as the input
+        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        return self.output_proj(output)
diff --git a/torchtune/models/gemma2/_attention_utils.py b/torchtune/models/gemma2/_attention_utils.py
new file mode 100644
index 0000000000..534ad9e051
--- /dev/null
+++ b/torchtune/models/gemma2/_attention_utils.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+import torch
+
+from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+
+if _SUPPORTS_FLEX_ATTENTION:
+    from functools import lru_cache
+
+    from torch.nn.attention.flex_attention import (
+        BlockMask,
+        create_block_mask,
+        flex_attention,
+    )
+
+    # flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+
+    @lru_cache
+    def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"):
+        block_mask = create_block_mask(score_mod, b, h, m, n, device=device)
+        return block_mask
+
+    # We cannot do nested compile, but flex attention only has perf benefits
+    # when compiled. To insulate it from the compiler, we wrap it with
+    # compiler.disable so that it can be used regardless of whether the model
+    # is compiled or not, and flex attention always remains compiled.
+    @torch.compiler.disable(recursive=False)
+    def compile_friendly_flex_attention_with_score_and_block(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        block_mask: BlockMask,
+        score_mod: Any,
+    ) -> torch.Tensor:
+        """
+        Flex attention does not seem to work with my A6000 with the default options.
+        Using the options proposed here: https://github.com/pytorch/pytorch/issues/133254
+        """
+        return flex_attention(
+            q,
+            k,
+            v,
+            score_mod=score_mod,
+            block_mask=block_mask,
+            # kernel_options={
+            #     "BLOCK_M": 64,
+            #     "BLOCK_N": 64,
+            #     "BLOCK_M1": 32,
+            #     "BLOCK_N1": 64,
+            #     "BLOCK_M2": 64,
+            #     "BLOCK_N2": 32,
+            # },
+        )
+
+
+def flex_causal_sliding_window(sliding_window_size):
+    def sliding_window_causal_mask(b, h, q_idx, kv_idx):
+        """Causal mask and sliding window as proposed here:
+        https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
+        """
+        causal_mask = q_idx >= kv_idx
+        if sliding_window_size is None:
+            # if no sliding window, return the plain causal mask
+            return causal_mask
+        else:
+            windowed_mask = q_idx - kv_idx <= sliding_window_size
+            return causal_mask & windowed_mask
+
+    return sliding_window_causal_mask
+
+
+def flex_tanh_soft_capping_with_scaling(softcapping, query_pre_attn_scalar):
+    def tanh_soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
+        """
+        Handles both simple tanh soft capping and custom scaling.
+        """
+        if query_pre_attn_scalar is None:
+            # usual scaling included in FlexAttention
+            # TODO: could this be made faster with an approximate tanh?
+            # https://github.com/pytorch-labs/attention-gym/blob/f7c93ded4abf9fd8d7dc9d8bcbf57e420b891e2d/examples/flex_attn.ipynb#L733
+            score = score / softcapping
+            score = torch.tanh(score)
+            return score * softcapping
+        else:
+            score = score / softcapping * query_pre_attn_scalar**-0.5
+            score = torch.tanh(score)
+            return score * softcapping
+
+    return tanh_soft_capping_with_scaling
diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py
index 6478d8ec31..915430ce4a 100644
--- a/torchtune/models/gemma2/_component_builders.py
+++ b/torchtune/models/gemma2/_component_builders.py
@@ -22,7 +22,24 @@
 from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp
+from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+import logging
+from torchtune.utils._logging import get_logger, log_once
+
+_log: logging.Logger = get_logger()
+
+
+if _SUPPORTS_FLEX_ATTENTION:
+    from torchtune.models.gemma2._attention import FlexGemma2Attention
+    log_once(
+        _log,
+        "Using flex attention for Gemma2 attention computation.",
+        level=logging.DEBUG,
+    )
+    _flex_or_native_gemma2_attention = FlexGemma2Attention
+else:
+    _flex_or_native_gemma2_attention = Gemma2Attention

 """
 Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA.
@@ -47,7 +64,7 @@ def forward(self, attn_weights):
         attn_weights = attn_weights / self.capping_value
         attn_weights = torch.tanh(attn_weights)
         attn_weights = attn_weights * self.capping_value
-
+        return attn_weights

 class Gemma2FinalNorm(nn.Module):
     """
@@ -120,7 +137,8 @@ def gemma2(
     layers = torch.nn.ModuleList()
     for layer_idx in range(num_layers):
-        self_att = Gemma2Attention(
+
+        self_att = _flex_or_native_gemma2_attention(
             embed_dim=embed_dim,
             num_heads=num_heads,
             num_kv_heads=num_kv_heads,
@@ -316,7 +334,7 @@ def lora_gemma2_self_attention(
     use_dora: bool = False,
     quantize_base: bool = False,

-) -> Gemma2Attention:
+) -> _flex_or_native_gemma2_attention:
     if not lora_modules:
         raise ValueError(
             f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
@@ -392,7 +410,7 @@ def lora_gemma2_self_attention(

     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

-    self_att = Gemma2Attention(
+    self_att = _flex_or_native_gemma2_attention(
         embed_dim=embed_dim,
         num_heads=num_heads,
         num_kv_heads=num_kv_heads,
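Reviewer note, not part of the patch: for anyone unfamiliar with FlexAttention, the sketch below shows how a mask_mod/score_mod pair like the ones added in _attention_utils.py composes with create_block_mask and flex_attention. It is a minimal standalone Python example under assumed values (the sizes B, H, S, D, the sliding window of 32, and the soft cap of 50.0 are illustrative, not taken from the patch), and it requires a PyTorch build where FlexAttention is available (torch >= 2.5 on CUDA; CPU support depends on the release).

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative sizes only; none of these values come from the patch.
B, H, S, D = 1, 8, 128, 64
WINDOW, SOFTCAP = 32, 50.0
device = "cuda" if torch.cuda.is_available() else "cpu"

def mask_mod(b, h, q_idx, kv_idx):
    # Causal AND sliding-window predicate, mirroring flex_causal_sliding_window.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

def score_mod(score, b, h, q_idx, kv_idx):
    # Tanh soft capping, mirroring the query_pre_attn_scalar-is-None branch
    # of flex_tanh_soft_capping_with_scaling.
    return torch.tanh(score / SOFTCAP) * SOFTCAP

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))

# The BlockMask depends only on the sequence lengths and the window size,
# which is why the TODO in FlexGemma2Attention.forward suggests building it
# once and sharing it across layers instead of rebuilding it per layer.
block_mask = create_block_mask(mask_mod, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])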