From 2e83ab1b5d5dbf928c9e47ffc806beb699b836da Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 27 Jun 2024 10:59:33 -0700
Subject: [PATCH] [BugFix] Fix `MLPSpeculator` handling of
 `num_speculative_tokens` (#5876)

---
 vllm/config.py                                    | 10 +++++++---
 vllm/model_executor/models/mlp_speculator.py      | 15 ++++++++-------
 vllm/transformers_utils/configs/mlp_speculator.py |  3 +++
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 0c4d770e46847..119cb982f08b4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -920,15 +920,19 @@ def maybe_create_spec_config(
             max_logprobs=target_model_config.max_logprobs,
         )
 
-        if (draft_model_config.hf_config.model_type == "mlp_speculator"
+        draft_hf_config = draft_model_config.hf_config
+        if (draft_hf_config.model_type == "mlp_speculator"
                 and target_parallel_config.world_size != 1):
             # MLPSpeculator TP support will be added very soon
             raise ValueError(
                 "Speculative decoding with mlp_speculator models does not "
                 "yet support distributed inferencing (TP > 1).")
 
-        n_predict = getattr(draft_model_config.hf_config, "n_predict",
-                            None)
+        if (num_speculative_tokens is not None
+                and hasattr(draft_hf_config, "num_lookahead_tokens")):
+            draft_hf_config.num_lookahead_tokens = num_speculative_tokens
+
+        n_predict = getattr(draft_hf_config, "n_predict", None)
         if n_predict is not None:
             if num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index b18269777cd01..6e6b2d8a7edb0 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -11,6 +11,7 @@
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
 
 class MLPSpeculatorLayerNorm(nn.Module):
@@ -48,7 +49,7 @@ def forward(self, x):
 
 class MLPSpeculator(nn.Module):
 
-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
         super().__init__()
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
@@ -56,8 +57,7 @@ def __init__(self, config, **kwargs) -> None:
         self.inner_dim = config.inner_dim if config.inner_dim != 0 \
             else config.emb_dim
 
-        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
-                                              self.n_predict)
+        self.max_speculative_tokens = config.num_lookahead_tokens
 
         self.emb = nn.ModuleList([
             VocabParallelEmbedding(config.vocab_size,
@@ -137,7 +137,8 @@ def generate_proposals(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            param = params_dict[name.replace("speculator.", "")]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            param = params_dict.get(name.replace("speculator.", ""))
+            if param is not None:
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py
index dd1d92b861b81..e1c1f4a960128 100644
--- a/vllm/transformers_utils/configs/mlp_speculator.py
+++ b/vllm/transformers_utils/configs/mlp_speculator.py
@@ -35,6 +35,7 @@ def __init__(self,
             candidate tree.
             For each candidate branch in the tree, head n produces topk[n]
            additional sub-branches.
+            NOTE: This parameter is currently unused.
         n_candidates: int
             number of child candidates to create per sequence
         """
@@ -47,4 +48,6 @@ def __init__(self,
         self.n_predict = n_predict
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
+        self.num_lookahead_tokens = n_predict
+
         super().__init__(**kwargs)
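To make the semantics of the fix concrete, here is a minimal standalone sketch (not vLLM code; `StubDraftHFConfig` and `resolve_speculative_tokens` are hypothetical stand-ins) of the control flow the `config.py` hunk introduces: a user-supplied `num_speculative_tokens` is written into the draft model's HF config as `num_lookahead_tokens`, which `MLPSpeculator` later reads as `max_speculative_tokens`, rather than the model always sizing itself from the config's `n_predict`.

```python
# Hypothetical stand-ins; only the control flow mirrors the patched
# maybe_create_spec_config block above.

class StubDraftHFConfig:
    """Mimics MLPSpeculatorConfig: num_lookahead_tokens defaults to n_predict."""

    def __init__(self, n_predict: int):
        self.n_predict = n_predict
        self.num_lookahead_tokens = n_predict


def resolve_speculative_tokens(draft_hf_config, num_speculative_tokens):
    # New in this patch: push the user's value into the draft config so the
    # model, which reads config.num_lookahead_tokens, sees the same number.
    if (num_speculative_tokens is not None
            and hasattr(draft_hf_config, "num_lookahead_tokens")):
        draft_hf_config.num_lookahead_tokens = num_speculative_tokens

    n_predict = getattr(draft_hf_config, "n_predict", None)
    if n_predict is not None and num_speculative_tokens is None:
        # Default to the maximum number of heads the draft model defines.
        num_speculative_tokens = n_predict
    return num_speculative_tokens


cfg = StubDraftHFConfig(n_predict=5)
assert resolve_speculative_tokens(cfg, 3) == 3
assert cfg.num_lookahead_tokens == 3  # model now runs 3 heads, not 5

cfg = StubDraftHFConfig(n_predict=5)
assert resolve_speculative_tokens(cfg, None) == 5  # default is unchanged
```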
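For reference, an end-to-end usage sketch under stated assumptions: the model names are illustrative, and the `speculative_model`, `num_speculative_tokens`, and `use_v2_block_manager` arguments follow the speculative-decoding examples shipped with vLLM around this release; check the docs for your version.

```python
from vllm import LLM, SamplingParams

# Illustrative pairing of a target model with an MLPSpeculator draft model.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="ibm-fms/llama-13b-accelerator",
    # With this patch, the value below actually limits the speculator's
    # lookahead; previously MLPSpeculator sized itself from the draft
    # config's n_predict regardless of what the user requested.
    num_speculative_tokens=3,
    use_v2_block_manager=True,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```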