Revert LlamaKVCache due to memory increase (#1605)
jiminha authored and regisss committed Dec 13, 2024
1 parent f8b496a commit 4e56c47
Showing 1 changed file with 31 additions and 4 deletions.
optimum/habana/transformers/models/llama/modeling_llama.py (35 changes: 31 additions & 4 deletions)
@@ -30,9 +30,10 @@
 from ...modeling_attn_mask_utils import (
     _gaudi_prepare_4d_causal_attention_mask,
 )
-from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
+from ..modeling_all_models import Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig

+
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa

@@ -57,6 +58,7 @@

 import habana_frameworks.torch.core as htcore

+
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -382,7 +384,23 @@ def forward(
             padding_side,
         )

-class LlamaKVCache(KVCache):
+
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
         orig_cur = cur
@@ -399,6 +417,15 @@ def update(prev, cur, dim, idx, inp_seq_len):
         else:
             return torch.cat((prev, cur), dim=dim)

+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
+
+
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed
@@ -412,8 +439,8 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):

         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
-        self.k_cache = LlamaKVCache()
-        self.v_cache = LlamaKVCache()
+        self.k_cache = KVCache()
+        self.v_cache = KVCache()

         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads
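
Since the reinstated class appears above only as diff fragments, here is a minimal sketch of how its visible pieces fit together: allocate() reserves (or re-zeroes) a fixed-shape buffer, get_shape() reports it, and forward() delegates to update(), which with idx=None falls through to the torch.cat branch shown in the hunk above. The tensor sizes, dtype, and CPU device are illustrative assumptions rather than values from this commit, and the KVCache class from the diff is assumed to be in scope.

import torch

# Illustrative sizes only; not taken from the commit.
batch, heads, max_len, head_dim = 1, 8, 128, 64

k_cache = KVCache()
k_cache.allocate(
    inp_seq_len=16,
    dtype=torch.bfloat16,
    device="cpu",  # "hpu" on Gaudi; "cpu" keeps the sketch portable
    shape=(batch, heads, max_len, head_dim),
)
print(k_cache.get_shape())  # torch.Size([1, 8, 128, 64])

# A second allocate() with the same shape takes the else branch:
# it asserts the prompt length is unchanged and zero-fills the
# existing buffer in place instead of reallocating it.
k_cache.allocate(16, torch.bfloat16, "cpu", (batch, heads, max_len, head_dim))

# With idx=None, forward() reaches the torch.cat fallback of update()
# and appends the new states along the sequence dimension.
new_kv = torch.zeros(batch, heads, 1, head_dim, dtype=torch.bfloat16)
appended = k_cache(new_kv, dim=2, idx=None)
print(appended.shape)  # torch.Size([1, 8, 129, 64])

In the model's attention forward, a token index is normally passed instead of None so that decode-step keys and values are written into the preallocated cache in place; that branch lives in the portion of update() elided from this diff view.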
