Commit
[Model] MLPSpeculator speculative decoding support (vllm-project#4947)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
1 parent a61199a · commit 4e1b011

Showing 18 changed files with 523 additions and 40 deletions.
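Two of the new files are shown below: an offline-inference script that measures per-token generation latency with and without speculation, and the MLPSpeculator draft-model implementation itself.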
New file (+59 lines):

```python
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
```
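The number printed by `time_generation` is the average wall-clock time per generated token, so comparing the two runs gives the per-token speedup from speculation. As a rough illustration, here is a small variant (not part of the commit; the helper name is made up) that returns that latency so the speedup can be computed directly, reusing the imports from the script above:

```python
# Hypothetical helper, adapted from time_generation above, that returns the
# per-token latency instead of printing it.
def time_generation_latency(llm: LLM, prompts: List[str],
                            sampling_params: SamplingParams) -> float:
    llm.generate(prompts, sampling_params)  # warmup
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.time() - start
    num_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return elapsed / num_tokens  # seconds per generated token


# baseline = time_generation_latency(llm_baseline, prompts, sampling_params)
# speculative = time_generation_latency(llm_spec, prompts, sampling_params)
# print(f"Speedup: {baseline / speculative:.2f}x")
```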
New file (+143 lines):

```python
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x
```
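`MLPSpeculatorLayerNorm` is an RMS-style normalization: it rescales the last axis to unit root-mean-square and then applies a learned elementwise weight and bias. A minimal sketch (not part of the commit; the parameter values are illustrative) of what the forward pass does numerically:

```python
# Illustrative only: with weight=1 and bias=0, the layer divides each row by
# its root-mean-square. For [2, -2, 2, -2] the RMS is 2, so the output is
# approximately [1, -1, 1, -1] (up to the eps term).
import torch

norm = MLPSpeculatorLayerNorm(4)
with torch.no_grad():
    norm.weight.fill_(1.0)
    norm.bias.fill_(0.0)

x = torch.tensor([[2.0, -2.0, 2.0, -2.0]])
print(norm(x))  # tensor([[1.0000, -1.0000, 1.0000, -1.0000]])
```

The remainder of the new file defines the speculator itself: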
```python
class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
```
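One detail worth spelling out: inside `generate_proposals`, `states.add_(z, alpha=self.emb_weight / self.state_weight)` computes `states + (emb_weight / state_weight) * z`, which is the intended weighted sum `state_weight * states + emb_weight * z` scaled by the constant `1 / state_weight`. Because `MLPSpeculatorLayerNorm` divides by the RMS of its input, that constant factor cancels, which is what the "Let subsequent LN take care of denominator" comment relies on. A small check sketch (not part of the commit; shapes and values are arbitrary):

```python
# Illustrative check that the rescaled in-place add and the explicit weighted
# sum agree after the RMS-style normalization (up to the eps term).
import math
import torch

d, n_predict = 8, 3
state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt((1 - state_weight ** 2) * (d / 2))

ln = MLPSpeculatorLayerNorm(d)
with torch.no_grad():
    ln.weight.fill_(1.0)
    ln.bias.fill_(0.0)

states = torch.randn(1, 1, d)
z = torch.randn(1, 1, d)

explicit = ln(state_weight * states + emb_weight * z)
rescaled = ln(states + (emb_weight / state_weight) * z)
print(torch.allclose(explicit, rescaled, atol=1e-5))  # True
```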