Skip to content

add a wrapper for running phi-3-mini with kv cache #4498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/models/phi-3-mini/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .phi_3_mini import Phi3Mini

__all__ = [
Phi3Mini,
]
29 changes: 8 additions & 21 deletions examples/models/phi-3-mini/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from transformers import AutoTokenizer, Phi3ForCausalLM

from .static_cache import ETStaticCache
from .phi_3_mini import Phi3Mini

end_of_text_token = 32000

Expand Down Expand Up @@ -42,35 +42,22 @@ def _generate_token(args, model, prompt_tokens):
def _generate_token_with_kv_cache(args, model, prompt_tokens):
print("Generating tokens:", end="", flush=True)

result = model.forward(
input_ids=prompt_tokens,
use_cache=True,
return_dict=True,
past_key_values=ETStaticCache(
model.config,
prompt_tokens.shape[0],
args.seq_len + prompt_tokens.shape[-1],
device=model.device,
dtype=model.dtype,
),
)
model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])

current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
current_key_value = result.past_key_values
for input_pos in range(prompt_tokens.shape[-1]):
result = model.forward(
input_ids=prompt_tokens[:, input_pos : input_pos + 1],
)

current_token = torch.argmax(result, dim=-1).item()
print(f" {current_token}", end="", flush=True)

generated_tokens = [current_token]

while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
result = model.forward(
input_ids=torch.tensor([[current_token]], dtype=torch.long),
use_cache=True,
return_dict=True,
past_key_values=current_key_value,
)
current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
current_key_value = result.past_key_values
current_token = torch.argmax(result, dim=-1).item()
print(f" {current_token}", end="", flush=True)
generated_tokens.append(current_token)

Expand Down
36 changes: 36 additions & 0 deletions examples/models/phi-3-mini/phi_3_mini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn
from transformers import Phi3ForCausalLM

from .static_cache import ETStaticCache


class Phi3Mini(torch.nn.Module):

def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
super().__init__()
self.model = model
self.cache = ETStaticCache(
config=model.config,
max_batch_size=max_batch_size,
max_cache_len=max_seq_len,
device=self.model.device,
dtype=self.model.dtype,
)

def forward(
self,
input_ids: torch.LongTensor = None,
) -> torch.FloatTensor:
return self.model.forward(
input_ids=input_ids,
use_cache=True,
return_dict=True,
past_key_values=self.cache,
).logits[:, -1, :]
Loading