[Model] MLPSpeculator speculative decoding support #4947

Merged 42 commits on Jun 21, 2024

Changes from 7 commits

Commits (42):
efb0599
initial commit of mlp_speculator and hidden_states_worker to support …
JRosenkranz May 20, 2024
7a8eeff
Merge branch 'main' into mlp_speculator
JRosenkranz May 20, 2024
667ef88
removed fms_extras import
JRosenkranz May 20, 2024
d534ef2
updated with a working non-batch version - a lot hardcoded
JRosenkranz May 21, 2024
17541b6
updated experimental with working version - eager
JRosenkranz May 22, 2024
ac5a1da
fixed bug with speculator outputs
JRosenkranz May 22, 2024
6ba9a1e
removed comments; swapped to sampling in the example
JRosenkranz May 22, 2024
cb3aacf
Introduce MLPSpeculatorWorker and corresponding refactor
tdoublep May 27, 2024
bf2f102
Fix some issues with correctness + simplify API a bit
tdoublep May 27, 2024
6af4629
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill May 31, 2024
e0309a6
Fix typing and formatting
njhill May 31, 2024
abd42e7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 1, 2024
314f2ae
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 5, 2024
9dd1c50
Remove separate MLPSpeculatorModelRunner and other cleanup
njhill Jun 5, 2024
0d43097
Use sample_len in mlp_speculator
njhill Jun 5, 2024
9dd1608
Some more rework/simplification, still in progress
njhill Jun 6, 2024
ea677bd
Config cleanup
njhill Jun 7, 2024
b39c94f
Ignore weird mypi error only happening in CI
njhill Jun 7, 2024
ab96c2a
Try again to ignore weird ruff error
njhill Jun 7, 2024
e9af7e5
Try to ignore both ruff and mypy errs
njhill Jun 7, 2024
30dc5e6
yapf
njhill Jun 7, 2024
3a61052
Fix leftover HiddenStatesWorker references
njhill Jun 7, 2024
cc05972
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 7, 2024
693974e
Fix AutoConfig import, mlp spec worker docstring
njhill Jun 7, 2024
f1bafba
Some cleanup/simplification
njhill Jun 7, 2024
455b9a9
Rework handling of accepted tokens
njhill Jun 7, 2024
e583ae9
Filter hidden states in Top1Proposer when needed
njhill Jun 9, 2024
7bff0d1
Enable bonus token
njhill Jun 9, 2024
3d04037
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 9, 2024
bea97d7
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 11, 2024
3012553
Move hidden state logic to separate class
njhill Jun 11, 2024
b116e02
Default num_speculative_tokens based on speculator model config
njhill Jun 15, 2024
e7742e7
Move offline_inference example to separate file
njhill Jun 15, 2024
ee83331
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 15, 2024
444a709
ruff
njhill Jun 15, 2024
bb9fd32
Add comment per review
njhill Jun 15, 2024
fcc6606
Some simplification to MLPSpeculatorWorker._prepare_input_tensors
njhill Jun 15, 2024
ffc0bcf
Merge remote-tracking branch 'refs/remotes/origin/main' into mlp_spec…
njhill Jun 17, 2024
f3dc40a
Add check for TP == 1; TP support will be a fast-follow
njhill Jun 17, 2024
1b7e305
Fix test import
njhill Jun 17, 2024
46ceacd
Revert unrelated commit made by mistake
njhill Jun 17, 2024
d9ce339
Fix test mocks
njhill Jun 18, 2024
20 changes: 16 additions & 4 deletions examples/offline_inference.py
@@ -1,20 +1,32 @@
from transformers import AutoConfig

from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorConfig
from vllm import LLM, SamplingParams
AutoConfig.register("mlp_speculator", MLPSpeculatorConfig)

template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
]
prompts = [template.format(prompt) for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="ibm-granite/granite-7b-instruct", use_v2_block_manager=True, enforce_eager=True, speculative_model="ibm-granite/granite-7b-instruct-accelerator", num_speculative_tokens=5)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
import time
outputs = llm.generate(prompts, sampling_params)
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
print((end-start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
# Print the outputs.
for output in outputs:
prompt = output.prompt
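Since the old and new lines are interleaved above without +/- markers, the updated example reads roughly as follows. The model names, flags, and timing code are taken from the diff itself; the remaining lines are the standard offline-inference boilerplate, so treat this as a sketch rather than the exact committed file.

from transformers import AutoConfig
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorConfig

from vllm import LLM, SamplingParams

# Register the speculator config class so transformers can resolve "mlp_speculator".
AutoConfig.register("mlp_speculator", MLPSpeculatorConfig)

template = ("Below is an instruction that describes a task. Write a response that "
            "appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:")

prompts = ["Hello, my name is"]
prompts = [template.format(prompt) for prompt in prompts]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Target model plus MLP speculator; tensor parallelism is limited to TP=1 at this stage of the PR.
llm = LLM(model="ibm-granite/granite-7b-instruct",
          use_v2_block_manager=True,
          enforce_eager=True,
          speculative_model="ibm-granite/granite-7b-instruct-accelerator",
          num_speculative_tokens=5)

import time
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()

# Rough per-token latency over everything that was generated.
print((end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))

for output in outputs:
    print(output.prompt, output.outputs[0].text)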
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -935,7 +935,7 @@ def _verify_args(self) -> None:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")

if self.draft_model_config:
if self.draft_model_config and self.draft_model_config.hf_config.model_type != "mlp_speculator":
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)

1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -56,6 +56,7 @@
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_EMBEDDING_MODELS = {
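The new registry entry maps the architecture string found in the speculator checkpoint's config (MLPSpeculatorPreTrainedModel) to the module and class vLLM should load for it. Reduced to its essentials, the lazy-lookup pattern behind this table works like the sketch below; this illustrates the mechanism and is not vLLM's actual registry code.

import importlib

_MODELS = {
    # architecture name from config.json -> (module under vllm.model_executor.models, class name)
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

def resolve_model_cls(architecture: str):
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)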
135 changes: 135 additions & 0 deletions vllm/model_executor/models/mlp_speculator.py
@@ -0,0 +1,135 @@
from typing import Optional, List, Iterable, Tuple

import torch.nn as nn
import torch
import math
from vllm.attention import AttentionMetadata
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.sequence import SamplerOutput


class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
elementwise_scale_weight : torch.Tensor
learned scaling term after normalization?
elementwise_shift_bias : torch.Tensor
learned bias term after normalization?
eps : float
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
"""

def __init__(
self,
normalized_shape,
eps=1e-06,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.empty(normalized_shape))
self.bias = nn.Parameter(torch.empty(normalized_shape))
self.eps = eps

def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
return x

class MLPSpeculator(nn.Module):
def __init__(
self,
config,
**kwargs
) -> None:
super().__init__()
self.current_head_index = 0
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size, self.inner_dim, org_num_embeddings=config.vocab_size)
for _ in range(config.n_predict)
])

self.proj = nn.ModuleList([
nn.Linear((self.emb_dim if i == 0 else self.inner_dim), self.inner_dim, bias=False)
for i in range(config.n_predict)
])

self.head = nn.ModuleList([nn.Linear(self.inner_dim, self.vocab_size, bias=False) for _ in range(config.n_predict)])
self.ln = nn.ModuleList([MLPSpeculatorLayerNorm(self.inner_dim) for _ in range(config.n_predict)])

self.state_weight = 0.5 ** (0.5 / config.n_predict)
self.emb_weight = math.sqrt((1 - self.state_weight ** 2) * (self.inner_dim / 2))
self.activation = nn.GELU()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size, config.vocab_size, 1.0)
self.sampler = Sampler()

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# prune the hidden states
if self.current_head_index == 0:
if self.first_decode_step:
self.first_decode_step = False
else:
self.previous_hidden_state = self.previous_hidden_state.reshape(-1, self.n_predict + 1, self.previous_hidden_state.size(1))
self.previous_hidden_state = self.previous_hidden_state.gather(
1,
(self.accepted_token_lengths - 1)[:, None, None].expand(-1, 1, self.previous_hidden_state.size(2))
).squeeze(1) # b x d

# Project and predict
z = self.emb[self.current_head_index](input_ids[-1]) # b k d
state = self.proj[self.current_head_index](self.previous_hidden_state)
# Weighted add of state_weight*state and emb_weight*z
# Let subsequent LN take care of denominator
# state_weight is close to 1, so shouldn't be any precision issues
state = torch.add(state, z, alpha=self.emb_weight / self.state_weight)
state = self.activation(self.ln[self.current_head_index](state)) # b k d

# todo: not yet supporting top_k_tokens_per_head

self.previous_hidden_state = state
self.current_head_index += 1
return state

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
current_head_index = self.current_head_index - 1
logits = self.logits_processor(self.head[current_head_index].weight, hidden_states,
sampling_metadata)
return logits

def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name.replace("speculator.", "")]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
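Stripped of the vLLM plumbing, the speculator is a chain of n_predict small heads: each head embeds the most recent token, projects the carried hidden state, combines the two with fixed weights, and applies the RMS-style norm above (x * rsqrt(mean(x^2) + eps), then scale and shift) before predicting the next draft token. A minimal sketch of that loop, with greedy sampling standing in for vLLM's Sampler (the function name and the argmax are illustrative assumptions):

import torch

def propose_draft_tokens(spec, last_token_id: torch.Tensor,
                         hidden_state: torch.Tensor) -> torch.Tensor:
    """Walk the n_predict heads of an MLPSpeculator once.

    spec          -- an MLPSpeculator as defined above
    last_token_id -- token just sampled from the target model, shape (batch,)
    hidden_state  -- the target model's final hidden state, shape (batch, emb_dim)
    """
    draft_tokens = []
    state = hidden_state
    token = last_token_id
    for i in range(spec.n_predict):
        z = spec.emb[i](token)                                   # embed the previous token
        state = spec.proj[i](state)                              # project the carried state
        # Weighted add; the LayerNorm that follows absorbs the overall scale.
        state = state + (spec.emb_weight / spec.state_weight) * z
        state = spec.activation(spec.ln[i](state))
        logits = state @ spec.head[i].weight.t()                 # per-head projection to vocab
        token = logits.argmax(dim=-1)                            # greedy, for illustration only
        draft_tokens.append(token)
    return torch.stack(draft_tokens, dim=1)                      # (batch, n_predict)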
78 changes: 78 additions & 0 deletions vllm/spec_decode/hidden_states_worker.py
@@ -0,0 +1,78 @@
from typing import List, Optional

from vllm.sequence import SequenceGroupMetadata, ExecuteModelRequest, SamplerOutput
from vllm.worker.worker import Worker
import torch

class HiddenStatesWorker(Worker):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.speculator = None
self.prev_request_context_lengths = {}

def _get_hidden_states(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
):

(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_input
) = self.model_runner.prepare_input_tensors(seq_group_metadata_list)

if self.model_runner.lora_config:
self.model_runner.set_active_loras(lora_requests, lora_mapping)

# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.model_runner.graph_runners[graph_batch_size]
else:
model_executable = self.model_runner.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})

# save the previous hidden states for later use
hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model_runner.model.compute_logits(hidden_states, sampling_metadata)

# Only perform sampling in the driver worker.
if not self.model_runner.is_driver_worker:
return None

# Sample the next token.
output = self.model_runner.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)

return output, hidden_states


@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:

# reset the head to call in speculator
self.speculator.current_head_index = 0

sampler_output, hidden_states = self._get_hidden_states(execute_model_req.seq_group_metadata_list, self.gpu_cache)

# if we are executing the prompt, we need to flag the first decode step since pruning is handled differently
if execute_model_req.seq_group_metadata_list[0].is_prompt:
self.speculator.first_decode_step = True
self.speculator.previous_hidden_state = hidden_states
return [sampler_output]
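One subtle piece of state handling lives back in MLPSpeculator.forward: after a scored speculation round, each sequence carries n_predict + 1 candidate hidden states (one per proposed token plus the bonus position), and only the state belonging to the last accepted token should seed the next round. The gather that does this pruning is easy to misread, so here is a tiny self-contained illustration with made-up shapes and acceptance counts:

import torch

batch, n_predict, hidden = 2, 3, 4
# Flattened hidden states, as the model runner hands them over: (batch * (n_predict + 1), hidden).
states = torch.arange(batch * (n_predict + 1) * hidden, dtype=torch.float32)
states = states.reshape(batch * (n_predict + 1), hidden)
accepted_token_lengths = torch.tensor([1, 3])   # seq 0 accepted 1 token, seq 1 accepted 3

states = states.reshape(-1, n_predict + 1, hidden)                # (batch, n_predict + 1, hidden)
kept = states.gather(
    1,
    (accepted_token_lengths - 1)[:, None, None].expand(-1, 1, hidden),
).squeeze(1)                                                      # (batch, hidden)
print(kept.shape)   # torch.Size([2, 4])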
8 changes: 6 additions & 2 deletions vllm/spec_decode/multi_step_worker.py
@@ -8,6 +8,7 @@
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.model_runner import SingleStepSpeculativeModelRunner
from vllm.worker.worker import Worker


@@ -28,6 +29,7 @@ def __init__(self, *args, **kwargs):

# Lazy initialization list.
self._proposer: Top1Proposer
self.requires_kv_cache: bool

def init_device(self):
super().init_device()
@@ -39,6 +41,8 @@ def init_device(self):
max_proposal_len=self.max_model_len,
)

self.requires_kv_cache = not isinstance(self.model_runner, SingleStepSpeculativeModelRunner)

def set_include_gpu_probs_tensor(self):
# Need include_gpu_probs_tensor for multi_step_worker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@@ -66,8 +70,8 @@ def sampler_output(
copied_seq_group_metadata_list)

# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
if self.requires_kv_cache:
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list, sample_len)

# Run model sample_len times.
model_outputs = []
31 changes: 21 additions & 10 deletions vllm/spec_decode/spec_decode_worker.py
@@ -9,6 +9,7 @@
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.hidden_states_worker import HiddenStatesWorker
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -32,7 +33,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config = kwargs.get("speculative_config")
assert speculative_config is not None

target_worker = Worker(*args, **kwargs)
if speculative_config.draft_model_config.hf_config.model_type == "mlp_speculator":
target_worker = HiddenStatesWorker(*args, **kwargs)
else:
target_worker = Worker(*args, **kwargs)

draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
@@ -165,6 +169,8 @@ def init_device(self) -> None:
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
if isinstance(self.scorer_worker, HiddenStatesWorker):
self.scorer_worker.speculator = self.proposer_worker.model_runner.model

self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
@@ -212,23 +218,27 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
num_gpu_blocks, num_cpu_blocks = (
self.scorer_worker.determine_num_available_blocks())

scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes())
proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes())
if not isinstance(self.scorer_worker, HiddenStatesWorker):
scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes())
proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes())

new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks)
return new_num_gpu_blocks, num_cpu_blocks
num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks)

return num_gpu_blocks, num_cpu_blocks

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the cache engine of the scorer and proposer workers.
"""
self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,

if not isinstance(self.scorer_worker, HiddenStatesWorker):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)

def _broadcast_control_flow_decision(
Expand Down Expand Up @@ -291,6 +301,7 @@ def execute_model(
# Used for prefill.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
disable_all_speculation = disable_all_speculation or isinstance(self.scorer_worker, HiddenStatesWorker)
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)

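The determine_num_available_blocks change above is the memory-accounting side of the speculator having no KV cache: with a Transformer draft model the GPU blocks must be shared between scorer and proposer, while with the MLP speculator the split is skipped and the target model keeps everything. A rough numeric sketch; split_blocks below is a stand-in with the proportional behaviour assumed of vLLM's split_num_cache_blocks_evenly, and the byte sizes are invented:

def split_blocks(scorer_block_bytes: int, proposer_block_bytes: int,
                 total_blocks: int) -> int:
    # Shrink the block count so that scorer + proposer KV caches fit in the same memory budget.
    return int(total_blocks * scorer_block_bytes /
               (scorer_block_bytes + proposer_block_bytes))

total_blocks = 1200
# Transformer draft model: its KV cache competes with the target's.
print(split_blocks(2_097_152, 1_048_576, total_blocks))   # -> 800 blocks left for the target
# MLP speculator (this PR): no draft KV cache, the split is skipped,
# so the target model keeps all 1200 blocks.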