[executorch] generation.py with kv cache #3030

examples/models/llama2/runner/generation.py (370 additions, 0 deletions)
@@ -0,0 +1,370 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import json
from typing import List, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs

from executorch.examples.models.llama2.tokenizer.tiktoken import (
    Dialog,
    Message,
    Tokenizer,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.

Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.

Returns:
torch.Tensor: Sampled token indices.

Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
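
    Example (shape check only; the sampled index itself is random):
        >>> probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=-1)
        >>> sample_top_p(probs, p=0.9).shape
        torch.Size([1, 1])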
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token


class LlamaRunner:
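    """Runs text generation on an exported Llama2 ExecuTorch (.pte) program."""
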
    def __init__(self, model_path: str, tokenizer_path: str, model_args: ModelArgs):
        # model_path points to an exported ExecuTorch program (.pte file).
        self.model = _load_for_executorch(model_path)
        self.params = model_args
        self.tokenizer = Tokenizer(tokenizer_path)
        assert model_args.vocab_size == self.tokenizer.n_words

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
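        """
        Generate tokens for a batch of prompts.

        With ``use_kv_cache`` enabled, the model is stepped one position at a
        time: each call receives only the new token slice plus its starting
        position, and the cache holds the attention state for the prefix.
        Without it, the full visible prefix is re-run at every step. Sampling
        matches the llama reference loop: nucleus sampling via
        ``sample_top_p`` when ``temperature > 0``, greedy argmax otherwise.
        """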
        bsz = len(prompt_tokens)
        params = self.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        if self.params.use_kv_cache:
            min_prompt_len = 1

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        if min_prompt_len == total_len:
            if self.params.use_kv_cache:
                inputs = (tokens, pos)
            else:
                inputs = (tokens,)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
            pos = torch.tensor([prev_pos], dtype=torch.int64)
            if self.params.use_kv_cache:
                inputs = (tokens[:, prev_pos:cur_pos], pos)
            else:
                inputs = (tokens[:, :cur_pos],)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
                next_token = torch.where(
                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
                )

            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to after eos tok if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
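
        Example (illustrative; assumes a constructed ``runner``, so skipped by doctest):
            >>> runner.text_completion(["Once upon a time"], max_gen_len=32)  # doctest: +SKIP
            [{'generation': '...'}]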
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
generation_tokens, generation_logprobs = self.generate(
prompt_tokens=prompt_tokens,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)

if logprobs:
return [
{
"generation": self.tokenizer.decode(t),
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
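
        Example (illustrative; assumes a constructed ``runner``, so skipped by doctest):
            >>> dialogs = [[{"role": "user", "content": "What is ExecuTorch?"}]]
            >>> runner.chat_completion(dialogs)  # doctest: +SKIP
            [{'generation': {'role': 'assistant', 'content': '...'}}]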
"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1

prompt_tokens = [
self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
]
generation_tokens, generation_logprobs = self.generate(
prompt_tokens=prompt_tokens,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
)
if logprobs:
return [
{
"generation": {
"role": "assistant",
"content": self.tokenizer.decode(t),
},
"tokens": [self.tokenizer.decode([x]) for x in t],
"logprobs": logprobs_i,
}
for t, logprobs_i in zip(generation_tokens, generation_logprobs)
]
return [
{
"generation": {
"role": "assistant",
"content": self.tokenizer.decode(t),
},
}
for t in generation_tokens
]


def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="path to tokenizer model file",
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="text prompt to complete",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="sampling temperature; 0 selects greedy decoding",
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        default=False,
        action="store_true",
        help="whether the exported model was built with a KV cache",
    )

    parser.add_argument(
        "--max_gen_len",
        type=int,
        default=10,
        help="Maximum length of the generated response sequence.",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    with open(args.params, "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=args.kv_cache,
        **params,
    )
    runner = LlamaRunner(
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
        prompts=[args.prompt],
        max_gen_len=args.max_gen_len,
        temperature=args.temperature,
    )
    print(f"Result: {result}")


if __name__ == "__main__":
    main()  # pragma: no cover
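

# Example invocation (illustrative; the .pte, params, and tokenizer paths are
# placeholders for files produced by your own export):
#   python examples/models/llama2/runner/generation.py \
#       --pte llama2.pte \
#       --params params.json \
#       --tokenizer tokenizer.model \
#       --prompt "Hello" \
#       -kv \
#       --max_gen_len 32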