refine append cache
yangw1234 committed Sep 27, 2023
1 parent c433d50 commit 6d79c53
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -45,14 +45,17 @@ def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,


 def append_kv_cache(cache_k, cache_v, key_states, value_states):
-    new_size = (cache_k.size(0),
-                cache_k.size(1),
-                cache_k.size(2) + key_states.size(2),
-                cache_k.size(3))
+    size_0, size_1, old_length, size_3 = cache_k.size()
+    k_size_2 = key_states.size(2)
+    new_length = old_length + k_size_2
+    new_size = (size_0,
+                size_1,
+                new_length,
+                size_3)
     new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
-    new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states
+    new_cache_k[:, :, old_length:new_length, :] = key_states
     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
+    new_cache_v[:, :, old_length:new_length, :] = value_states
     return new_cache_k, new_cache_v
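
Note on the change: the refactor hoists the repeated cache_k.size(2) / key_states.size(2) calls into named locals, and in passing removes the mixed cache_v/cache_k indexing in the old value-cache slice. The function appends into a cache whose underlying storage was already reserved (by extend_kv_cache in the same file); as_strided merely widens the visible view over that storage, so no reallocation or copy of existing entries happens. Below is a minimal usage sketch under stated assumptions: the shapes, the zero-length initial view, and the plain-torch reservation are illustrative stand-ins for what extend_kv_cache does, not the exact BigDL setup.

    import torch
    from bigdl.llm.transformers.models.utils import append_kv_cache

    # Illustrative KV-cache layout: (batch, num_heads, seq_len, head_dim).
    batch, heads, max_len, head_dim = 1, 32, 256, 128

    # Reserve storage for up to max_len tokens, then expose zero-length
    # views over it. (In BigDL this reservation is handled by
    # extend_kv_cache; plain torch keeps the sketch self-contained.)
    storage_k = torch.empty(batch, heads, max_len, head_dim)
    storage_v = torch.empty(batch, heads, max_len, head_dim)
    cache_k = storage_k.as_strided((batch, heads, 0, head_dim), storage_k.stride())
    cache_v = storage_v.as_strided((batch, heads, 0, head_dim), storage_v.stride())

    # One decoding step yields key/value states for a single new token.
    key_states = torch.randn(batch, heads, 1, head_dim)
    value_states = torch.randn(batch, heads, 1, head_dim)

    # append_kv_cache widens the strided view by one token and writes the
    # new states in place; the existing cache contents are untouched.
    cache_k, cache_v = append_kv_cache(cache_k, cache_v, key_states, value_states)
    print(cache_k.shape)  # torch.Size([1, 32, 1, 128])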


