Commit 5b37524

helunwencser authored and facebook-github-bot committed
Add customized static cache implementation (#4490)
Summary: Pull Request resolved: #4490 (imported using ghimport)

Test Plan: Imported from OSS

Reviewed By: iseeyuan

Differential Revision: D60554455

Pulled By: helunwencser

fbshipit-source-id: defc2953afb265b5e21b2fa540c3b1eb2e90d0a8
1 parent 1114539 commit 5b37524

File tree

3 files changed: +56 -1 lines changed

examples/models/phi-3-mini/__init__.py

Whitespace-only changes.

examples/models/phi-3-mini/eager.py

Lines changed: 14 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 from transformers import AutoTokenizer, Phi3ForCausalLM
 
+from .static_cache import ETStaticCache
+
 end_of_text_token = 32000
 
 
@@ -40,7 +42,18 @@ def _generate_token(args, model, prompt_tokens):
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)
 
-    result = model.forward(input_ids=prompt_tokens, use_cache=True, return_dict=True)
+    result = model.forward(
+        input_ids=prompt_tokens,
+        use_cache=True,
+        return_dict=True,
+        past_key_values=ETStaticCache(
+            model.config,
+            prompt_tokens.shape[0],
+            args.seq_len + prompt_tokens.shape[-1],
+            device=model.device,
+            dtype=model.dtype,
+        ),
+    )
 
     current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
     current_key_value = result.past_key_values
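
The hunk above sizes the cache for the prompt plus up to args.seq_len generated tokens, then runs the prefill pass. For context, here is a minimal sketch of the greedy decode loop such a prefill typically feeds into, assuming the same ETStaticCache instance is passed back on every step so new keys and values land in the preallocated buffers; decode_loop and its parameters are illustrative names, not code from this PR (the EOS value 32000 matches end_of_text_token in the file above):

import torch

def decode_loop(model, cache, first_token, max_new_tokens, eos_token=32000):
    # Greedy decode: feed one token at a time; earlier context lives in the
    # static cache, so input_ids only carries the newest token.
    tokens = [first_token]
    current = first_token
    while len(tokens) < max_new_tokens and current != eos_token:
        result = model.forward(
            input_ids=torch.tensor([[current]], dtype=torch.long, device=model.device),
            use_cache=True,
            return_dict=True,
            past_key_values=cache,  # same preallocated ETStaticCache every step
        )
        current = torch.argmax(result.logits[:, -1, :], dim=-1).item()
        tokens.append(current)
    return tokens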
examples/models/phi-3-mini/static_cache.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+from transformers import PretrainedConfig, StaticCache
+
+
+class ETStaticCache(StaticCache):
+    """
+    A customized static cache implementation, which overrides a few methods to make it exportable to ExecuTorch.
+    This can be removed once transformers supports static cache for Phi3 properly.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: int,
+        device,
+        dtype=torch.float32,
+    ) -> None:
+        super().__init__(
+            config=config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+        )
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()
+
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
+        return self.get_seq_length(layer_idx)
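
The get_seq_length override is what makes this cache export-friendly: rather than consulting Python-side state, it derives the number of filled slots from the cache contents themselves. StaticCache preallocates zero-filled key/value tensors of shape (batch, num_heads, max_cache_len, head_dim), so counting the positions whose key vector contains any nonzero element yields the current sequence length. A toy sketch of that trick on a bare tensor (the shapes here are made up for illustration):

import torch

# Zero-initialized cache slice for one batch element and one head:
# (batch, num_heads, max_cache_len, head_dim)
max_cache_len, head_dim = 8, 4
key_cache = torch.zeros(1, 1, max_cache_len, head_dim)

# Simulate three tokens having been written into the first three slots.
key_cache[0, 0, :3] = torch.randn(3, head_dim)

# any(dim=-1) marks slots whose key vector has a nonzero entry;
# summing the marks counts how many slots are occupied.
seq_length = key_cache[0, 0].any(dim=-1).sum().item()
print(seq_length)  # -> 3

This relies on a written key vector never being exactly all zeros, which holds in practice for floating-point activations; get_usable_length then simply reuses that count.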
