Commit

[BugFix] Fix MLPSpeculator handling of num_speculative_tokens (vl…
njhill authored and prashantgupta24 committed Jun 28, 2024
1 parent 610956e commit 2e83ab1
Showing 3 changed files with 18 additions and 10 deletions.
vllm/config.py: 7 additions & 3 deletions
@@ -920,15 +920,19 @@ def maybe_create_spec_config(
             max_logprobs=target_model_config.max_logprobs,
         )
 
-        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+        draft_hf_config = draft_model_config.hf_config
+        if (draft_hf_config.model_type == "mlp_speculator"
                 and target_parallel_config.world_size != 1):
             # MLPSpeculator TP support will be added very soon
             raise ValueError(
                 "Speculative decoding with mlp_speculator models does not "
                 "yet support distributed inferencing (TP > 1).")
 
-        n_predict = getattr(draft_model_config.hf_config, "n_predict",
-                            None)
+        if (num_speculative_tokens is not None
+                and hasattr(draft_hf_config, "num_lookahead_tokens")):
+            draft_hf_config.num_lookahead_tokens = num_speculative_tokens
+
+        n_predict = getattr(draft_hf_config, "n_predict", None)
         if n_predict is not None:
             if num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
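The net effect of this hunk: a user-supplied num_speculative_tokens now overwrites num_lookahead_tokens on the draft model's HF config before the model is constructed, instead of being silently ignored. A minimal sketch of the override logic in isolation; the SimpleNamespace stand-in for the HF config is an assumption for illustration (the real object is an MLPSpeculatorConfig):

from types import SimpleNamespace

# Stand-in for the draft model's HF config; MLPSpeculatorConfig now
# defaults num_lookahead_tokens to n_predict (see the last file below).
draft_hf_config = SimpleNamespace(n_predict=3, num_lookahead_tokens=3)

num_speculative_tokens = 2  # e.g. the user's --num-speculative-tokens

# Mirrors the added logic: override the config's lookahead with the
# user's setting when one was given and the config supports it.
if (num_speculative_tokens is not None
        and hasattr(draft_hf_config, "num_lookahead_tokens")):
    draft_hf_config.num_lookahead_tokens = num_speculative_tokens

assert draft_hf_config.num_lookahead_tokens == 2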
vllm/model_executor/models/mlp_speculator.py: 8 additions & 7 deletions
@@ -11,6 +11,7 @@
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
 
 class MLPSpeculatorLayerNorm(nn.Module):
@@ -48,16 +49,15 @@ def forward(self, x):

 class MLPSpeculator(nn.Module):
 
-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
         super().__init__()
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
         self.inner_dim = config.inner_dim if config.inner_dim != 0 \
             else config.emb_dim
 
-        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
-                                              self.n_predict)
+        self.max_speculative_tokens = config.num_lookahead_tokens
 
         self.emb = nn.ModuleList([
             VocabParallelEmbedding(config.vocab_size,
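The old getattr line could never observe the user's setting, because nothing ever set a max_speculative_tokens attribute on the config, so the fallback to n_predict always won. A minimal sketch of that failure mode; the OldStyleConfig class is a hypothetical stand-in:

class OldStyleConfig:
    # Illustrative stand-in: the real config never carried a
    # max_speculative_tokens attribute, only n_predict.
    n_predict = 3

config = OldStyleConfig()

# Old behavior: the fallback always wins, so --num-speculative-tokens
# had no effect on the model.
old_value = getattr(config, "max_speculative_tokens", config.n_predict)
assert old_value == 3

# New behavior: the model reads config.num_lookahead_tokens, which
# vllm/config.py has already overridden with the user's value.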
@@ -137,7 +137,8 @@ def generate_proposals(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            param = params_dict[name.replace("speculator.", "")]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            param = params_dict.get(name.replace("speculator.", ""))
+            if param is not None:
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
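The load_weights change replaces the indexing lookup with .get() plus a None check, so checkpoint tensors with no counterpart in the vLLM module are now skipped instead of raising KeyError. A minimal sketch with a plain dict standing in for the model's parameter dict; names and values are illustrative:

params_dict = {"emb.0.weight": "param-A"}  # stand-in parameter dict

checkpoint = [
    ("speculator.emb.0.weight", "tensor-A"),  # matches a known parameter
    ("speculator.unused.bias", "tensor-B"),   # no counterpart in the model
]

for name, loaded_weight in checkpoint:
    # Old: params_dict[...] raised KeyError on "unused.bias".
    # New: .get() returns None and the tensor is skipped.
    param = params_dict.get(name.replace("speculator.", ""))
    if param is not None:
        print(f"loading {name} -> {param}")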
vllm/transformers_utils/configs/mlp_speculator.py: 3 additions & 0 deletions
@@ -35,6 +35,7 @@ def __init__(self,
                 candidate tree.
                 For each candidate branch in the tree, head n produces topk[n]
                 additional sub-branches.
+                NOTE: This parameter is currently unused.
             n_candidates: int
                 number of child candidates to create per sequence
         """
@@ -47,4 +48,6 @@ def __init__(self,
         self.n_predict = n_predict
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
+        self.num_lookahead_tokens = n_predict
+
         super().__init__(**kwargs)
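With this change num_lookahead_tokens always exists on MLPSpeculatorConfig and defaults to n_predict, which is exactly what makes the hasattr check added in vllm/config.py succeed. A minimal sketch of the default-then-override flow; passing n_predict as a keyword argument is an assumption based on the signature shown in this diff:

from vllm.transformers_utils.configs import MLPSpeculatorConfig

config = MLPSpeculatorConfig(n_predict=3)
print(config.num_lookahead_tokens)  # 3, defaulted from n_predict

# When --num-speculative-tokens is set, vllm/config.py overrides it:
config.num_lookahead_tokens = 2
# MLPSpeculator.__init__ then reads this as max_speculative_tokens.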