[Model] MLPSpeculator speculative decoding support #4947
Merged

Commits (42)
efb0599 (JRosenkranz): initial commit of mlp_speculator and hidden_states_worker to support …
7a8eeff (JRosenkranz): Merge branch 'main' into mlp_speculator
667ef88 (JRosenkranz): removed fms_extras import
d534ef2 (JRosenkranz): updated with a working non-batch version - a lot hardcoded
17541b6 (JRosenkranz): updated experimental with working version - eager
ac5a1da (JRosenkranz): fixed bug with speculator outputs
6ba9a1e (JRosenkranz): removed comments; swapped to sampling in the example
cb3aacf (tdoublep): Introduce MLPSpeculatorWorker and corresponding refactor
bf2f102 (tdoublep): Fix some issues with correctness + simplify API a bit
6af4629 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
e0309a6 (njhill): Fix typing and formatting
abd42e7 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
314f2ae (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
9dd1c50 (njhill): Remove separate MLPSpeculatorModelRunner and other cleanup
0d43097 (njhill): Use sample_len in mlp_speculator
9dd1608 (njhill): Some more rework/simplification, still in progress
ea677bd (njhill): Config cleanup
b39c94f (njhill): Ignore weird mypy error only happening in CI
ab96c2a (njhill): Try again to ignore weird ruff error
e9af7e5 (njhill): Try to ignore both ruff and mypy errs
30dc5e6 (njhill): yapf
3a61052 (njhill): Fix leftover HiddenStatesWorker references
cc05972 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
693974e (njhill): Fix AutoConfig import, mlp spec worker docstring
f1bafba (njhill): Some cleanup/simplification
455b9a9 (njhill): Rework handling of accepted tokens
e583ae9 (njhill): Filter hidden states in Top1Proposer when needed
7bff0d1 (njhill): Enable bonus token
3d04037 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
bea97d7 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
3012553 (njhill): Move hidden state logic to separate class
b116e02 (njhill): Default num_speculative_tokens based on speculator model config
e7742e7 (njhill): Move offline_inference example to separate file
ee83331 (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
444a709 (njhill): ruff
bb9fd32 (njhill): Add comment per review
fcc6606 (njhill): Some simplification to MLPSpeculatorWorker._prepare_input_tensors
ffc0bcf (njhill): Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
f3dc40a (njhill): Add check for TP == 1; TP support will be a fast-follow
1b7e305 (njhill): Fix test import
46ceacd (njhill): Revert unrelated commit made by mistake
d9ce339 (njhill): Fix test mocks
New file: an offline inference example that times generation with and without the MLPSpeculator draft model.
@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
        enforce_eager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
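The number printed by time_generation is wall-clock seconds per generated token. In the example above, num_speculative_tokens is not passed, so it defaults to the speculator's n_predict (see the config changes below). For illustration, it can also be capped explicitly; a minimal sketch, assuming the accelerator's n_predict is at least 3 (the value 3 and the overall variant are assumptions, not part of the example file):

from vllm import LLM

# Hypothetical variant: speculate at most 3 tokens per step instead of
# inheriting the draft model's n_predict. Must not exceed n_predict.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="ibm-fms/llama-13b-accelerator",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
)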
Changes to the model and speculative-decoding configuration (ModelConfig and SpeculativeConfig):
@@ -229,15 +229,17 @@ def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        total_num_attention_heads = getattr(self.hf_text_config,
+                                            "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
                 f"Total number of attention heads ({total_num_attention_heads})"
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -336,8 +338,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:

     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+        num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
+        return num_heads // parallel_config.tensor_parallel_size

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers

Review comments on this change:
- Oh it seems …
- Our speculator models do not use attention :)
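To illustrate why the getattr defaults above matter, here is a small standalone sketch; the toy config object is an assumption standing in for an MLPSpeculator HF config, which defines no attention or layer-count fields:

from types import SimpleNamespace

# Toy stand-in for an MLPSpeculator HF config: it has n_predict and embedding
# sizes, but no num_attention_heads / num_hidden_layers attributes.
spec_cfg = SimpleNamespace(n_predict=3, emb_dim=5120, inner_dim=0)

# Direct attribute access would raise AttributeError for such a config;
# getattr with a default of 0 keeps the divisibility checks well-defined.
num_heads = getattr(spec_cfg, "num_attention_heads", 0)
tensor_parallel_size = 1
assert num_heads % tensor_parallel_size == 0  # check passes for TP == 1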
@@ -813,7 +815,8 @@ def maybe_create_spec_config(
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
             num_speculative_tokens (Optional[int]): The number of speculative
-                tokens, if provided.
+                tokens, if provided. Will default to the number in the draft
+                model config if present, otherwise is required.
             speculative_max_model_len (Optional[int]): The maximum model len of
                 the speculative model. Used when testing the ability to skip
                 speculation for some sequences.
@@ -836,24 +839,18 @@ def maybe_create_spec_config(
             the necessary conditions are met, else None.
         """

-        if speculative_model is None and num_speculative_tokens is None:
+        if speculative_model is None:
+            if num_speculative_tokens is not None:
+                raise ValueError("num_speculative_tokens was provided without "
+                                 "speculative_model.")
             return None

-        if speculative_model is not None and num_speculative_tokens is None:
-            raise ValueError(
-                "Expected both speculative_model and "
-                "num_speculative_tokens to be provided, but found "
-                f"{speculative_model=} and {num_speculative_tokens=}.")
-
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
             raise ValueError("Expect the batch size threshold of disabling "
                              "speculative decoding is > 1, but got "
                              f"{speculative_disable_by_batch_size=}")

-        assert (speculative_model is not None
-                and num_speculative_tokens is not None)
-
         if enable_chunked_prefill:
             raise ValueError(
                 "Speculative decoding and chunked prefill are "
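With this restructuring, passing num_speculative_tokens without a speculative model now fails fast with a descriptive ValueError instead of tripping the later assertion. A minimal sketch of the misuse (the model name is chosen only for illustration):

from vllm import LLM

# Hypothetical misuse: num_speculative_tokens without speculative_model.
# Engine construction raises a clear ValueError during config creation.
try:
    LLM(model="facebook/opt-125m", num_speculative_tokens=5)
except ValueError as err:
    print(err)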
@@ -907,6 +904,27 @@ def maybe_create_spec_config(
             max_logprobs=target_model_config.max_logprobs,
         )

+        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+                and target_parallel_config.world_size != 1):
+            # MLPSpeculator TP support will be added very soon
+            raise ValueError(
+                "Speculative decoding with mlp_speculator models does not "
+                "yet support distributed inferencing (TP > 1).")
+
+        n_predict = getattr(draft_model_config.hf_config, "n_predict",
+                            None)
+        if n_predict is not None:
+            if num_speculative_tokens is None:
+                # Default to max value defined in draft model config.
+                num_speculative_tokens = n_predict
+            elif num_speculative_tokens > n_predict:
+                # Verify provided value doesn't exceed the maximum
+                # supported by the draft model.
+                raise ValueError(
+                    "Expected both speculative_model and "
+                    "num_speculative_tokens to be provided, but found "
+                    f"{speculative_model=} and {num_speculative_tokens=}.")
+
         draft_model_config.max_model_len = (
             SpeculativeConfig._maybe_override_draft_max_model_len(
                 speculative_max_model_len,
@@ -918,6 +936,12 @@ def maybe_create_spec_config(
             SpeculativeConfig.create_draft_parallel_config(
                 target_parallel_config))

+        if num_speculative_tokens is None:
+            raise ValueError(
+                "num_speculative_tokens must be provided with "
+                "speculative_model unless the draft model config contains an "
+                "n_predict parameter.")
+
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
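Taken together, these hunks implement a simple resolution rule for num_speculative_tokens: fall back to the draft model's n_predict when present, never exceed it, and require an explicit value otherwise. A standalone sketch of that rule (not vLLM code; the helper name is made up):

def resolve_num_speculative_tokens(requested, n_predict):
    # Mirrors the config logic above: default to the draft model's n_predict,
    # and never allow more tokens than the speculator supports.
    if n_predict is not None:
        if requested is None:
            return n_predict
        if requested > n_predict:
            raise ValueError(f"{requested} exceeds draft model "
                             f"n_predict={n_predict}")
    elif requested is None:
        raise ValueError("num_speculative_tokens is required when the draft "
                         "model config has no n_predict")
    return requested


assert resolve_num_speculative_tokens(None, 3) == 3  # default from config
assert resolve_num_speculative_tokens(2, 3) == 2     # explicit cap honored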
New file: the MLPSpeculator draft model implementation.
@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculator(nn.Module):

    def __init__(self, config, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
                                              self.n_predict)

        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])

        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])

        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name.replace("speculator.", "")]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
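For intuition, the per-head recurrence in generate_proposals can be sketched with plain PyTorch outside vLLM. Dimensions and random weights below are made up, nn.LayerNorm stands in for MLPSpeculatorLayerNorm, and argmax stands in for the sampler:

import math
import torch
import torch.nn as nn

# Toy recurrence mirroring MLPSpeculator.generate_proposals: each head embeds
# the previously sampled token, mixes it with the running hidden state, and
# predicts logits for the next speculative position.
vocab, emb_dim, inner_dim, n_heads = 32, 8, 8, 3
emb = nn.ModuleList([nn.Embedding(vocab, inner_dim) for _ in range(n_heads)])
proj = nn.ModuleList([
    nn.Linear(emb_dim if i == 0 else inner_dim, inner_dim, bias=False)
    for i in range(n_heads)
])
head = nn.ModuleList(
    [nn.Linear(inner_dim, vocab, bias=False) for _ in range(n_heads)])
ln = nn.ModuleList([nn.LayerNorm(inner_dim) for _ in range(n_heads)])
act = nn.GELU()

state_weight = 0.5**(0.5 / n_heads)
emb_weight = math.sqrt((1 - state_weight**2) * (inner_dim / 2))

state = torch.randn(2, 1, emb_dim)      # b x 1 x d: last target hidden state
last = torch.randint(0, vocab, (2, 1))  # b x 1: last sampled token ids
for i in range(n_heads):
    z = emb[i](last)
    state = proj[i](state) + (emb_weight / state_weight) * z
    state = act(ln[i](state))
    logits = head[i](state)             # b x 1 x vocab
    last = logits.argmax(-1)            # greedy stand-in for the sampler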
Review comment: Impressive! It seems MLPSpeculator without cudagraph is still much faster than the original model with cudagraph.