
Commit 00a4a88

Update
[ghstack-poisoned]
1 parent: 367ca68

File tree: 3 files changed (+50, -1 lines)


examples/models/phi-3-mini/__init__.py

Whitespace-only changes.

examples/models/phi-3-mini/eager.py

Lines changed: 14 additions & 1 deletion
@@ -14,6 +14,8 @@
 
 from transformers import AutoTokenizer, Phi3ForCausalLM
 
+from .static_cache import ETStaticCache
+
 end_of_text_token = 32000
 
 
@@ -40,7 +42,18 @@ def _generate_token(args, model, prompt_tokens):
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)
 
-    result = model.forward(input_ids=prompt_tokens, use_cache=True, return_dict=True)
+    result = model.forward(
+        input_ids=prompt_tokens,
+        use_cache=True,
+        return_dict=True,
+        past_key_values=ETStaticCache(
+            model.config,
+            prompt_tokens.shape[0],
+            args.seq_len + prompt_tokens.shape[-1],
+            device=model.device,
+            dtype=model.dtype,
+        ),
+    )
 
     current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
     current_key_value = result.past_key_values
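
In the prefill call above, the KV cache is pre-allocated for the whole run: max_cache_len is sized as the prompt length plus the number of tokens to generate (args.seq_len + prompt_tokens.shape[-1]). The decode loop that follows this hunk is outside the diff; the sketch below is a hypothetical reconstruction, not part of the commit, of how such a loop would reuse the returned cache, feeding one new token per step and relying on the overridden get_seq_length() to locate the next write position.

    # Hypothetical decode loop, not part of this commit. Assumes batch size 1
    # and mirrors the forward() pattern used in the prefill call above.
    import torch

    end_of_text_token = 32000  # as defined in eager.py

    def _decode_with_cache_sketch(args, model, result):
        current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
        current_key_value = result.past_key_values
        generated = [current_token]
        while current_token != end_of_text_token and len(generated) < args.seq_len:
            result = model.forward(
                input_ids=torch.tensor([[current_token]], device=model.device),
                use_cache=True,
                return_dict=True,
                past_key_values=current_key_value,  # reuse the pre-allocated cache
            )
            current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
            current_key_value = result.past_key_values
            generated.append(current_token)
        return generated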
examples/models/phi-3-mini/static_cache.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from transformers import PretrainedConfig, StaticCache
+
+
+class ETStaticCache(StaticCache):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: int,
+        device,
+        dtype=torch.float32,
+    ) -> None:
+        super().__init__(
+            config=config,
+            max_batch_size=max_batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+        )
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()
+
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
+        return self.get_seq_length(layer_idx)
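
The get_seq_length override is the crux of the new file (the ET prefix presumably refers to ExecuTorch): StaticCache pre-allocates zero-filled key/value buffers of length max_cache_len, so the number of tokens already written can be recovered by counting the sequence slots whose key vectors contain any non-zero value, and .item() converts the result to a plain Python int. get_usable_length is likewise simplified to report exactly that count. A toy illustration of the counting trick, with made-up shapes and not part of the commit:

    # Toy illustration of the get_seq_length() computation. The
    # [batch, heads, seq, head_dim] layout mirrors StaticCache's per-layer
    # key buffers; the sizes here are invented.
    import torch

    max_cache_len, head_dim = 8, 4
    key_cache = torch.zeros(1, 1, max_cache_len, head_dim)  # zero-initialized
    key_cache[0, 0, :3] = 1.0  # pretend three tokens have been cached

    # A slot counts as occupied if any element of its key vector is non-zero.
    occupied = key_cache[0, 0].any(dim=-1).sum().item()
    print(occupied)  # 3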
