fix mlp and kv cache, disable flex attention
Optimox committed Oct 26, 2024
1 parent 0d53660 commit 6b50916
Showing 5 changed files with 50 additions and 30 deletions.
4 changes: 2 additions & 2 deletions recipes/configs/gemma2/27B_lora_single_device.yaml
@@ -60,10 +60,10 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

 # Fine-tuning arguments
-batch_size: 8
+batch_size: 2
 epochs: 1
 max_steps_per_epoch: null
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
 compile: False

 # Training env
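A quick sanity check on the change above, assuming the usual convention that one optimizer step sees batch_size × gradient_accumulation_steps samples: the effective batch size is unchanged, only the per-step memory footprint shrinks.

```python
# Effective batch size before and after this config change (assumes the common
# batch_size * gradient_accumulation_steps convention; illustrative only).
old_effective = 8 * 2  # batch_size=8, gradient_accumulation_steps=2
new_effective = 2 * 8  # batch_size=2, gradient_accumulation_steps=8
assert old_effective == new_effective == 16
```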
2 changes: 1 addition & 1 deletion recipes/configs/gemma2/2B_lora_single_device.yaml
@@ -37,7 +37,7 @@ model:

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/gemma-2b/
+  checkpoint_dir: /tmp/gemma-2-2b/
  checkpoint_files: [
    model-00001-of-00003.safetensors,
    model-00002-of-00003.safetensors,
22 changes: 19 additions & 3 deletions torchtune/models/gemma2/_attention.py
@@ -12,7 +12,11 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType
 from torchtune.modules.kv_cache import KVCache
-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet
+# flex attention is disabled for now until we resolve the issue
+_SUPPORTS_FLEX_ATTENTION = False

 if _SUPPORTS_FLEX_ATTENTION:
     from torch.nn.attention.flex_attention import create_block_mask
@@ -132,6 +136,11 @@ def __init__(
         else:
             self.scaling = self.head_dim**-0.5

+        # this flag indicates whether to update the kv-cache during forward
+        # passes. when disabled, we can have the cache setup but still
+        # perform normal forward passes
+        self.cache_enabled = False
+
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
     ) -> None:
@@ -156,6 +165,7 @@ def setup_cache(
             head_dim=self.head_dim,
             dtype=dtype,
         )
+        self.cache_enabled = True

     def reset_cache(self):
         """Reset the key value caches."""
@@ -283,7 +293,7 @@ def forward(
             k = self.k_norm(k)

         # Update key-value cache
-        if self.kv_cache is not None:
+        if self.kv_cache is not None and self.cache_enabled:
             k, v = self.kv_cache.update(k, v)

         q.mul_(self.scaling)
@@ -436,6 +446,11 @@ def __init__(
             self.softcapping, self.scaling
         )

+        # this flag indicates whether to update the kv-cache during forward
+        # passes. when disabled, we can have the cache setup but still
+        # perform normal forward passes
+        self.cache_enabled = False
+
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
     ) -> None:
@@ -460,6 +475,7 @@ def setup_cache(
             head_dim=self.head_dim,
             dtype=dtype,
         )
+        self.cache_enabled = True

     def reset_cache(self):
         """Reset the key value caches."""
@@ -587,7 +603,7 @@ def forward(
             k = self.k_norm(k)

         # Update key-value cache
-        if self.kv_cache is not None:
+        if self.kv_cache is not None and self.cache_enabled:
             k, v = self.kv_cache.update(k, v)

         # TODO: how to avoid recomputing the same block mask at every layer?
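To illustrate the cache_enabled flag introduced above, here is a minimal, self-contained sketch of the pattern with a toy cache (hypothetical module, not torchtune's KVCache API): the cache can be allocated via setup_cache, but forward passes only write to it while the flag is set, so the same module still supports plain, cache-free forwards.

```python
# Toy illustration of the cache_enabled pattern (not torchtune code).
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Stand-in for an attention module with an optional KV cache."""

    def __init__(self, head_dim: int = 64) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.kv_cache = None
        # the cache may be allocated yet left untouched while this flag is False
        self.cache_enabled = False

    def setup_cache(self, batch_size: int, max_seq_len: int) -> None:
        # allocate cache buffers and start updating them on subsequent forwards
        self.kv_cache = {
            "k": torch.zeros(batch_size, max_seq_len, self.head_dim),
            "v": torch.zeros(batch_size, max_seq_len, self.head_dim),
            "pos": 0,
        }
        self.cache_enabled = True

    def forward(self, k: torch.Tensor, v: torch.Tensor):
        # mirrors the guard added in this commit: update the cache only when it
        # exists *and* updates are enabled
        if self.kv_cache is not None and self.cache_enabled:
            pos, seq_len = self.kv_cache["pos"], k.shape[1]
            self.kv_cache["k"][:, pos : pos + seq_len] = k
            self.kv_cache["v"][:, pos : pos + seq_len] = v
            self.kv_cache["pos"] = pos + seq_len
            k = self.kv_cache["k"][:, : pos + seq_len]
            v = self.kv_cache["v"][:, : pos + seq_len]
        return k, v


attn = ToyAttention()
k = v = torch.randn(1, 4, 64)
print(attn(k, v)[0].shape)  # torch.Size([1, 4, 64]) -- no cache, plain forward
attn.setup_cache(batch_size=1, max_seq_len=16)
print(attn(k, v)[0].shape)  # torch.Size([1, 4, 64]) -- cached prefill of 4 tokens
```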
19 changes: 10 additions & 9 deletions torchtune/models/gemma2/_attention_utils.py
@@ -8,7 +8,10 @@

 import torch

-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet
+# flex attention is disabled for now until we resolve the issue
+_SUPPORTS_FLEX_ATTENTION = False

 if _SUPPORTS_FLEX_ATTENTION:
     from functools import lru_cache
@@ -19,7 +22,9 @@
         flex_attention,
     )

-    # flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+    flex_attention_compiled = torch.compile(
+        flex_attention, dynamic=False, mode="max-autotune"
+    )

     @lru_cache
     def create_block_mask_cached(score_mod, b, h, m, n, device="cuda"):
@@ -40,7 +45,7 @@ def compile_friendly_flex_attention_with_score_and_block(
    ) -> torch.Tensor:
        """
        Flex attention does not seem to work with my A6000 with the default options.
-        Using proposed options here: https://github.com/pytorch/pytorch/issues/133254
+        Using proposed options here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1815058279
        """
        return flex_attention(
            q,
@@ -49,12 +54,8 @@
            score_mod=score_mod,
            block_mask=block_mask,
            # kernel_options={
-           #     "BLOCK_M": 64,
-           #     "BLOCK_N": 64,
-           #     "BLOCK_M1": 32,
-           #     "BLOCK_N1": 64,
-           #     "BLOCK_M2": 64,
-           #     "BLOCK_N2": 32,
+           #     "BLOCK_M": 32,
+           #     "BLOCK_N": 32,
            # },
        )

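For orientation, a minimal sketch of the pattern this utility wraps once flex attention is re-enabled: a causal block mask combined with a tanh soft-capping score_mod, the Gemma2-style attention variant. It assumes PyTorch >= 2.5 (where torch.nn.attention.flex_attention lives); shapes and the softcap value are illustrative, not torchtune's exact configuration.

```python
# Sketch of flex attention with Gemma2-style logit soft-capping (assumes PyTorch >= 2.5).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 4, 128, 64
softcap = 50.0  # illustrative value


def causal_mask(b, h, q_idx, kv_idx):
    # keep only keys at or before the query position
    return q_idx >= kv_idx


def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # tanh soft-capping of attention logits, as used by Gemma2
    return torch.tanh(score / softcap) * softcap


q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cpu")
# plain eager call for readability; the file above wraps it in torch.compile for speed,
# and a GPU may be required depending on the torch build
out = flex_attention(q, k, v, score_mod=softcap_score_mod, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 4, 128, 64])
```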
33 changes: 18 additions & 15 deletions torchtune/models/gemma2/_component_builders.py
@@ -22,7 +22,10 @@
 from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp
-from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION
+# The flex attention implementation for gemma2 is not working yet
+# flex attention is disabled for now until we resolve the issue
+_SUPPORTS_FLEX_ATTENTION = False

 import logging
 from torchtune.utils._logging import get_logger, log_once
@@ -132,12 +135,12 @@ def gemma2(
     """
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

-    mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
-
     layers = torch.nn.ModuleList()

     for layer_idx in range(num_layers):
+
+        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
+
         self_att = _flex_or_native_gemma2_attention(
             embed_dim=embed_dim,
             num_heads=num_heads,
@@ -244,25 +247,25 @@ def lora_gemma2(
         TransformerDecoder: Instantiation of Gemma model with LoRA applied to
             a subset of the attention projections in each layer.
     """
-    if apply_lora_to_mlp:
-        mlp = lora_gemma_mlp(
-            dim=embed_dim,
-            hidden_dim=intermediate_dim,
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            use_dora=use_dora,
-            quantize_base=quantize_base,
-        )
-    else:
-        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)

     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)

     layers = torch.nn.ModuleList()

     for layer_idx in range(num_layers):
+        if apply_lora_to_mlp:
+            mlp = lora_gemma_mlp(
+                dim=embed_dim,
+                hidden_dim=intermediate_dim,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                use_dora=use_dora,
+                quantize_base=quantize_base,
+            )
+        else:
+            mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
         self_att = lora_gemma2_self_attention(
             lora_modules=lora_attn_modules,
             embed_dim=embed_dim,
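To see why the MLP construction moved inside the per-layer loop, here is a small illustrative sketch (a plain nn.Linear stands in for the Gemma MLP): appending one module instance to every layer makes all layers share a single set of weights, while constructing it per iteration gives each layer its own parameters.

```python
# Why per-layer module construction matters (illustrative; nn.Linear stands in for the MLP).
from torch import nn

embed_dim, num_layers = 16, 4

# buggy pattern (before this commit): one instance reused for every layer -> shared weights
shared_mlp = nn.Linear(embed_dim, embed_dim)
shared_layers = nn.ModuleList(shared_mlp for _ in range(num_layers))

# fixed pattern (this commit): a fresh instance per layer -> independent weights
per_layer = nn.ModuleList(nn.Linear(embed_dim, embed_dim) for _ in range(num_layers))


def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


print(count_params(shared_layers))  # 272  -> parameters of a single Linear (deduplicated)
print(count_params(per_layer))      # 1088 -> num_layers independent Linears
```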
