From 6b50916c0ea28d3bdedd29f627f6acefeb12fe38 Mon Sep 17 00:00:00 2001
From: Optimox
Date: Sat, 26 Oct 2024 11:59:13 +0200
Subject: [PATCH] fix mlp and kv cache, disable flex attention

---
 .../gemma2/27B_lora_single_device.yaml        |  4 +--
 .../configs/gemma2/2B_lora_single_device.yaml |  2 +-
 torchtune/models/gemma2/_attention.py         | 22 +++++++++++--
 torchtune/models/gemma2/_attention_utils.py   | 19 ++++++-----
 .../models/gemma2/_component_builders.py      | 33 ++++++++++---------
 5 files changed, 50 insertions(+), 30 deletions(-)

diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml
index 7879dd1fc..56727e529 100644
--- a/recipes/configs/gemma2/27B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/27B_lora_single_device.yaml
@@ -60,10 +60,10 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 
 # Fine-tuning arguments
-batch_size: 8
+batch_size: 2
 epochs: 1
 max_steps_per_epoch: null
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
 compile: False
 
 # Training env
diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml
index 49b59846c..484f133b4 100644
--- a/recipes/configs/gemma2/2B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/2B_lora_single_device.yaml
@@ -37,7 +37,7 @@ model:
 
 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/gemma-2b/
+  checkpoint_dir: /tmp/gemma-2-2b/
   checkpoint_files: [
     model-00001-of-00003.safetensors,
     model-00002-of-00003.safetensors,
diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py
index e4d0949d0..c769c5b47 100644
--- a/torchtune/models/gemma2/_attention.py
+++ b/torchtune/models/gemma2/_attention.py
@@ -12,7 +12,11 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType
 from torchtune.modules.kv_cache import KVCache
-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet.
+# Flex attention is disabled for now until the issue is resolved.
+_SUPPORTS_FLEX_ATTENTION = False
 
 if _SUPPORTS_FLEX_ATTENTION:
     from torch.nn.attention.flex_attention import create_block_mask
@@ -132,6 +136,11 @@ def __init__(
         else:
             self.scaling = self.head_dim**-0.5
 
+        # this flag indicates whether to update the kv-cache during forward
+        # passes. When disabled, we can have the cache set up but still
+        # perform normal forward passes
+        self.cache_enabled = False
+
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
     ) -> None:
@@ -156,6 +165,7 @@ def setup_cache(
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
+            self.cache_enabled = True
 
     def reset_cache(self):
         """Reset the key value caches."""
@@ -283,7 +293,7 @@ def forward(
             k = self.k_norm(k)
 
         # Update key-value cache
-        if self.kv_cache is not None:
+        if self.kv_cache is not None and self.cache_enabled:
             k, v = self.kv_cache.update(k, v)
 
         q.mul_(self.scaling)
@@ -436,6 +446,11 @@ def __init__(
             self.softcapping, self.scaling
         )
 
+        # this flag indicates whether to update the kv-cache during forward
+        # passes. When disabled, we can have the cache set up but still
+        # perform normal forward passes
+        self.cache_enabled = False
+
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
     ) -> None:
@@ -460,6 +475,7 @@ def setup_cache(
                 head_dim=self.head_dim,
                 dtype=dtype,
             )
+            self.cache_enabled = True
 
     def reset_cache(self):
         """Reset the key value caches."""
@@ -587,7 +603,7 @@ def forward(
             k = self.k_norm(k)
 
         # Update key-value cache
-        if self.kv_cache is not None:
+        if self.kv_cache is not None and self.cache_enabled:
             k, v = self.kv_cache.update(k, v)
 
         # TODO: how to avoid to compute same block mask at every layer ?
diff --git a/torchtune/models/gemma2/_attention_utils.py b/torchtune/models/gemma2/_attention_utils.py
index 534ad9e05..8a17c8ec8 100644
--- a/torchtune/models/gemma2/_attention_utils.py
+++ b/torchtune/models/gemma2/_attention_utils.py
@@ -8,7 +8,10 @@
 
 import torch
 
-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet.
+# Flex attention is disabled for now until the issue is resolved.
+_SUPPORTS_FLEX_ATTENTION = False
 
 if _SUPPORTS_FLEX_ATTENTION:
     from functools import lru_cache
@@ -19,7 +22,9 @@
         flex_attention,
     )
 
-    # flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+    flex_attention_compiled = torch.compile(
+        flex_attention, dynamic=False, mode="max-autotune"
+    )
 
     @lru_cache
     def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"):
@@ -40,7 +45,7 @@ def compile_friendly_flex_attention_with_score_and_block(
     ) -> torch.Tensor:
         """
         Flex attention does not seem to work with my A6000 with the default options.
-        Using proposed options here: https://github.com/pytorch/pytorch/issues/133254
+        Using proposed options here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1815058279
         """
         return flex_attention(
             q,
@@ -49,12 +54,8 @@ def compile_friendly_flex_attention_with_score_and_block(
             score_mod=score_mod,
             block_mask=block_mask,
             # kernel_options={
-            #     "BLOCK_M": 64,
-            #     "BLOCK_N": 64,
-            #     "BLOCK_M1": 32,
-            #     "BLOCK_N1": 64,
-            #     "BLOCK_M2": 64,
-            #     "BLOCK_N2": 32,
+            #     "BLOCK_M": 32,
+            #     "BLOCK_N": 32,
             # },
         )
 
diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py
index 915430ce4..253cea2f8 100644
--- a/torchtune/models/gemma2/_component_builders.py
+++ b/torchtune/models/gemma2/_component_builders.py
@@ -22,7 +22,10 @@
 from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp
-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet.
+# Flex attention is disabled for now until the issue is resolved.
+_SUPPORTS_FLEX_ATTENTION = False
 
 import logging
 from torchtune.utils._logging import get_logger, log_once
@@ -132,12 +135,12 @@ def gemma2(
     """
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
 
-    mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
-
     layers = torch.nn.ModuleList()
     for layer_idx in range(num_layers):
 
+        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
+
         self_att = _flex_or_native_gemma2_attention(
             embed_dim=embed_dim,
             num_heads=num_heads,
@@ -244,18 +247,6 @@ def lora_gemma2(
         TransformerDecoder: Instantiation of Gemma model with LoRA applied to
         a subset of the attention projections in each layer.
     """
-    if apply_lora_to_mlp:
-        mlp = lora_gemma_mlp(
-            dim=embed_dim,
-            hidden_dim=intermediate_dim,
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            use_dora=use_dora,
-            quantize_base=quantize_base,
-        )
-    else:
-        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
 
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
@@ -263,6 +254,18 @@ def lora_gemma2(
 
     layers = torch.nn.ModuleList()
     for layer_idx in range(num_layers):
+        if apply_lora_to_mlp:
+            mlp = lora_gemma_mlp(
+                dim=embed_dim,
+                hidden_dim=intermediate_dim,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                use_dora=use_dora,
+                quantize_base=quantize_base,
+            )
+        else:
+            mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
         self_att = lora_gemma2_self_attention(
             lora_modules=lora_attn_modules,
             embed_dim=embed_dim,
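
Note on the mlp change (commentary, not part of the patch): building gemma_mlp(...) once before the layer loop and reusing the returned module means every decoder layer holds a reference to the same MLP instance, so all layers train one shared set of feed-forward weights; constructing it inside the loop, as the patch does in both builders, gives each layer its own parameters. A minimal PyTorch sketch of the difference, with illustrative module names rather than torchtune code:

    import torch.nn as nn

    # One module built outside the loop: every "layer" references the same weights.
    shared_mlp = nn.Linear(8, 8)
    shared_layers = nn.ModuleList([nn.Sequential(shared_mlp) for _ in range(3)])
    assert shared_layers[0][0].weight is shared_layers[2][0].weight

    # One module built per iteration (what the patch now does for gemma_mlp and
    # lora_gemma_mlp): each layer owns independent parameters.
    per_layer = nn.ModuleList([nn.Sequential(nn.Linear(8, 8)) for _ in range(3)])
    assert per_layer[0][0].weight is not per_layer[2][0].weight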
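
Note on the cache_enabled flag (commentary, not part of the patch): setup_cache() allocates the KVCache and sets cache_enabled to True, and forward() now writes to the cache only when the cache exists and the flag is on, so a module can keep its cache allocated while temporarily running ordinary forward passes. A standalone sketch of that guard pattern, simplified and not the torchtune classes:

    import torch

    class TinyCachedAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.kv_cache = None
            self.cache_enabled = False

        def setup_cache(self):
            # allocate the cache and enable updates, mirroring the patch
            self.kv_cache = {"k": [], "v": []}
            self.cache_enabled = True

        def forward(self, k, v):
            # update only when the cache exists and updates are enabled
            if self.kv_cache is not None and self.cache_enabled:
                self.kv_cache["k"].append(k)
                self.kv_cache["v"].append(v)
            return k, v

    attn = TinyCachedAttention()
    attn.setup_cache()
    attn.cache_enabled = False  # cache stays allocated, updates are skipped
    attn(torch.zeros(1), torch.zeros(1))
    assert len(attn.kv_cache["k"]) == 0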