[Model] MLPSpeculator speculative decoding support #4947

Merged
42 commits, merged Jun 21, 2024

Commits
efb0599
initial commit of mlp_speculator and hidden_states_worker to support …
JRosenkranz May 20, 2024
7a8eeff
Merge branch 'main' into mlp_speculator
JRosenkranz May 20, 2024
667ef88
removed fms_extras import
JRosenkranz May 20, 2024
d534ef2
updated with a working non-batch version - a lot hardcoded
JRosenkranz May 21, 2024
17541b6
updated experimental with working version - eager
JRosenkranz May 22, 2024
ac5a1da
fixed bug with speculator outputs
JRosenkranz May 22, 2024
6ba9a1e
removed comments; swapped to sampling in the example
JRosenkranz May 22, 2024
cb3aacf
Introduce MLPSpeculatorWorker and corresponding refactor
tdoublep May 27, 2024
bf2f102
Fix some issues with correctness + simplify API a bit
tdoublep May 27, 2024
6af4629
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill May 31, 2024
e0309a6
Fix typing and formatting
njhill May 31, 2024
abd42e7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 1, 2024
314f2ae
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 5, 2024
9dd1c50
Remove separate MLPSpeculatorModelRunner and other cleanup
njhill Jun 5, 2024
0d43097
Use sample_len in mlp_speculator
njhill Jun 5, 2024
9dd1608
Some more rework/simplification, still in progress
njhill Jun 6, 2024
ea677bd
Config cleanup
njhill Jun 7, 2024
b39c94f
Ignore weird mypi error only happening in CI
njhill Jun 7, 2024
ab96c2a
Try again to ignore weird ruff error
njhill Jun 7, 2024
e9af7e5
Try to ignore both ruff and mypy errs
njhill Jun 7, 2024
30dc5e6
yapf
njhill Jun 7, 2024
3a61052
Fix leftover HiddenStatesWorker references
njhill Jun 7, 2024
cc05972
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 7, 2024
693974e
Fix AutoConfig import, mlp spec worker docstring
njhill Jun 7, 2024
f1bafba
Some cleanup/simplification
njhill Jun 7, 2024
455b9a9
Rework handling of accepted tokens
njhill Jun 7, 2024
e583ae9
Filter hidden states in Top1Proposer when needed
njhill Jun 9, 2024
7bff0d1
Enable bonus token
njhill Jun 9, 2024
3d04037
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 9, 2024
bea97d7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 11, 2024
3012553
Move hidden state logic to separate class
njhill Jun 11, 2024
b116e02
Default num_speculative_tokens based on speculator model config
njhill Jun 15, 2024
e7742e7
Move offline_inference example to separate file
njhill Jun 15, 2024
ee83331
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 15, 2024
444a709
ruff
njhill Jun 15, 2024
bb9fd32
Add comment per review
njhill Jun 15, 2024
fcc6606
Some simplification to MLPSpeculatorWorker._prepare_input_tensors
njhill Jun 15, 2024
ffc0bcf
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 17, 2024
f3dc40a
Add check for TP == 1; TP support will be a fast-follow
njhill Jun 17, 2024
1b7e305
Fix test import
njhill Jun 17, 2024
46ceacd
Revert unrelated commit made by mistake
njhill Jun 17, 2024
d9ce339
Fix test mocks
njhill Jun 18, 2024
Files changed
59 changes: 59 additions & 0 deletions examples/offline_inference_mlpspeculator.py
@@ -0,0 +1,59 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
sampling_params: SamplingParams):
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
# Warmup first
llm.generate(prompts, sampling_params)
llm.generate(prompts, sampling_params)
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
# Print the outputs.
for output in outputs:
generated_text = output.outputs[0].text
print(f"text: {generated_text!r}")


if __name__ == "__main__":

template = (
"Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:\n")

# Sample prompts.
prompts = [
"Write about the president of the United States.",
]
prompts = [template.format(prompt) for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

# Create an LLM without spec decoding
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

print("Without speculation")
time_generation(llm, prompts, sampling_params)

del llm
gc.collect()

# Create an LLM with spec decoding
llm = LLM(
model="meta-llama/Llama-2-13b-chat-hf",
speculative_model="ibm-fms/llama-13b-accelerator",
# These are currently required for MLPSpeculator decoding
use_v2_block_manager=True,
enforce_eager=True,
)
Review comment (Collaborator):
Impressive! It seems MLPSpeculator without cudagraph is still much faster than the original model with cudagraph.


print("With speculation")
time_generation(llm, prompts, sampling_params)
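As a side note on reading this example's output: the script prints seconds per generated token for each run. Below is a minimal sketch, not part of the PR, of turning the two measurements into an explicit speedup figure; the helper name, its return value, and the `llm_base`/`llm_spec` variables are assumptions for illustration only.

```python
import time

from vllm import LLM, SamplingParams


# Hypothetical variant of time_generation that returns seconds per token
# instead of only printing it, so the two runs can be compared directly.
def seconds_per_token(llm: LLM, prompts, sampling_params: SamplingParams) -> float:
    llm.generate(prompts, sampling_params)  # warmup
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.time() - start
    num_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    return elapsed / num_tokens


# baseline = seconds_per_token(llm_base, prompts, sampling_params)
# speculative = seconds_per_token(llm_spec, prompts, sampling_params)
# print(f"speedup: {baseline / speculative:.2f}x")
```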
8 changes: 6 additions & 2 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -456,7 +456,9 @@ def test_k_equals_zero(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]

draft_worker.device = 'cuda'
target_worker.device = 'cuda'
@@ -497,7 +499,9 @@ def test_empty_input_batch(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]

draft_worker.device = 'cuda'
target_worker.device = 'cuda'
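Context for why these mocks now set hidden_states explicitly: a bare MagicMock fabricates attributes on access, so an `is not None` check on the sampler output's hidden_states (the code path this PR adds for MLPSpeculator) would be taken even in these non-MLPSpeculator tests. A standalone sketch of that behavior, assuming nothing beyond the standard library:

```python
from unittest.mock import MagicMock

mock_output = MagicMock()
# Accessing an unset attribute on a MagicMock returns another MagicMock,
# which is not None, so the hidden-states path would be entered.
print(mock_output.hidden_states is not None)  # True

mock_output.hidden_states = None
print(mock_output.hidden_states is not None)  # False, baseline path as intended
```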
4 changes: 2 additions & 2 deletions tests/spec_decode/test_utils.py
@@ -2,8 +2,8 @@

import pytest

from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len


def test_get_all_seq_ids():
54 changes: 39 additions & 15 deletions vllm/config.py
@@ -229,15 +229,17 @@ def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_text_config.num_attention_heads
total_num_attention_heads = getattr(self.hf_text_config,
"num_attention_heads", 0)
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")

total_num_hidden_layers = self.hf_text_config.num_hidden_layers
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
@@ -336,8 +338,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:

def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
return self.hf_text_config.num_attention_heads // \
parallel_config.tensor_parallel_size
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
Review comment (Collaborator):
Oh, it seems get_num_attention_heads can sometimes be 0. Curious, when would that be the case?

Review comment (Member):
Our speculator models do not use attention :)


def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
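To illustrate the exchange above: an MLPSpeculator config defines no num_attention_heads at all, so the getattr fallback yields 0, and since 0 is divisible by any tensor parallel size the head-count check becomes a no-op for the speculator. A toy sketch, where the config objects are made-up stand-ins rather than vLLM classes:

```python
from types import SimpleNamespace

decoder_cfg = SimpleNamespace(num_attention_heads=40)
speculator_cfg = SimpleNamespace(n_predict=3, emb_dim=5120)  # no attention heads


def heads_per_rank(hf_text_config, tensor_parallel_size: int) -> int:
    # Mirrors the getattr fallback in the diff: a missing attribute counts
    # as 0 heads, and 0 % tp == 0, so the divisibility check never fails.
    num_heads = getattr(hf_text_config, "num_attention_heads", 0)
    assert num_heads % tensor_parallel_size == 0
    return num_heads // tensor_parallel_size


print(heads_per_rank(decoder_cfg, 2))     # 20
print(heads_per_rank(speculator_cfg, 2))  # 0
```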
Expand Down Expand Up @@ -813,7 +815,8 @@ def maybe_create_spec_config(
speculative_model (Optional[str]): The name of the speculative
model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
@@ -836,24 +839,18 @@
the necessary conditions are met, else None.
"""

if speculative_model is None and num_speculative_tokens is None:
if speculative_model is None:
if num_speculative_tokens is not None:
raise ValueError("num_speculative_tokens was provided without "
"speculative_model.")
return None

if speculative_model is not None and num_speculative_tokens is None:
raise ValueError(
"Expected both speculative_model and "
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")

assert (speculative_model is not None
and num_speculative_tokens is not None)

if enable_chunked_prefill:
raise ValueError(
"Speculative decoding and chunked prefill are "
@@ -907,6 +904,27 @@ def maybe_create_spec_config(
max_logprobs=target_model_config.max_logprobs,
)

if (draft_model_config.hf_config.model_type == "mlp_speculator"
and target_parallel_config.world_size != 1):
# MLPSpeculator TP support will be added very soon
raise ValueError(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")

n_predict = getattr(draft_model_config.hf_config, "n_predict",
None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict:
# Verify provided value doesn't exceed the maximum
# supported by the draft model.
raise ValueError(
"Expected both speculative_model and "
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")

draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
@@ -918,6 +936,12 @@
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))

if num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter.")

return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
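Taken together, the config.py changes establish a clear resolution order for num_speculative_tokens. The function below is a simplified, standalone restatement for this write-up, not code from the PR:

```python
from typing import Optional


def resolve_num_speculative_tokens(num_speculative_tokens: Optional[int],
                                   n_predict: Optional[int]) -> int:
    """Simplified sketch of the logic in maybe_create_spec_config."""
    if n_predict is not None:
        # MLPSpeculator draft configs carry n_predict, so the engine
        # argument becomes optional and defaults to the draft maximum.
        if num_speculative_tokens is None:
            return n_predict
        if num_speculative_tokens > n_predict:
            raise ValueError(
                f"draft model supports at most {n_predict} speculative "
                f"tokens, but {num_speculative_tokens} were requested")
        return num_speculative_tokens
    if num_speculative_tokens is None:
        # Draft models without n_predict still require an explicit value.
        raise ValueError("num_speculative_tokens must be provided")
    return num_speculative_tokens
```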
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -59,6 +59,7 @@
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_EMBEDDING_MODELS = {
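For context on how this registry entry is reached, a simplified sketch of the lookup (the real resolution lives in the model loader; this stand-in only illustrates the mapping): the speculator checkpoint's config.json lists MLPSpeculatorPreTrainedModel under architectures, and that name is mapped to the MLPSpeculator class in mlp_speculator.py.

```python
# Simplified stand-in for the loader's architecture lookup.
_MODELS = {
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}


def resolve_model_cls(architectures):
    for arch in architectures:
        if arch in _MODELS:
            return _MODELS[arch]  # (module name, class name)
    raise ValueError(f"Unsupported architectures: {architectures}")


print(resolve_model_cls(["MLPSpeculatorPreTrainedModel"]))
# ('mlp_speculator', 'MLPSpeculator')
```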
143 changes: 143 additions & 0 deletions vllm/model_executor/models/mlp_speculator.py
@@ -0,0 +1,143 @@
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
eps : float
Safety term to prevent division by zero. Make sure the chosen value
fits in the range of your encoding scheme
(i.e. fp16 requires eps >= 6e-8).
"""

def __init__(
self,
normalized_shape,
eps=1e-06,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps

def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
return x


class MLPSpeculator(nn.Module):

def __init__(self, config, **kwargs) -> None:
super().__init__()
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim

self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
self.n_predict)

self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
self.inner_dim,
org_num_embeddings=config.vocab_size)
for _ in range(self.max_speculative_tokens)
])

self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
self.inner_dim,
bias=False) for i in range(self.max_speculative_tokens)
])

self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([
MLPSpeculatorLayerNorm(self.inner_dim)
for _ in range(self.max_speculative_tokens)
])

self.state_weight = 0.5**(0.5 / config.n_predict)
self.emb_weight = math.sqrt(
(1 - self.state_weight**2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
config.vocab_size, 1.0)
self.sampler = Sampler()

def generate_proposals(
self,
input_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
f"{num_predict_tokens} were requested")

# b x 1 x d
previous_hidden_states = previous_hidden_states.unsqueeze(1)

# b x 1
last_tokens = input_ids.unsqueeze(1)

next_tokens = []

for head_index in range(num_predict_tokens):

# Project and predict
z = self.emb[head_index](last_tokens) # b k d
states = self.proj[head_index](previous_hidden_states)

# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
states.add_(z, alpha=self.emb_weight / self.state_weight)

states = self.activation(self.ln[head_index](states)) # b k d
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states

logits = self.logits_processor(self.head[head_index].weight,
states, sampling_metadata)

output = self.sampler(logits.flatten(0, 1), sampling_metadata)
last_tokens = output.sampled_token_ids
next_tokens.append(output)

return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name.replace("speculator.", "")]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
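One detail of generate_proposals worth unpacking is the in-place `states.add_(z, alpha=self.emb_weight / self.state_weight)` together with the comment about letting the subsequent norm handle the denominator: dividing both terms by state_weight changes the weighted sum only by a positive scalar, which the RMS-style normalization cancels. A quick standalone check with toy tensors (not vLLM code, and omitting the norm's affine part):

```python
import torch

state_weight, emb_weight = 0.9, 0.5
states = torch.randn(2, 1, 8)
z = torch.randn(2, 1, 8)

# Explicit weighted sum vs. the fused in-place form used in the model.
explicit = state_weight * states + emb_weight * z
fused = states + (emb_weight / state_weight) * z  # == explicit / state_weight


def rms_normalize(x, eps=1e-6):
    # Same normalization step as MLPSpeculatorLayerNorm, before weight/bias.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


# Both inputs normalize to (numerically) the same tensor.
print(torch.allclose(rms_normalize(explicit), rms_normalize(fused), atol=1e-5))
```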