Commit 2ccd13a

benchislett and WoosukKwon authored and committed
[Spec Decode] Make EAGLE3 draft token ID mapping optional (vllm-project#18488)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 66fb733 commit 2ccd13a

1 file changed: +9 -1 lines changed


vllm/model_executor/models/llama_eagle3.py

Lines changed: 9 additions & 1 deletion
@@ -214,6 +214,9 @@ def compute_logits(
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
+        if self.draft_id_to_target_id is None:
+            return logits
+
         base = torch.arange(self.config.draft_vocab_size, device=logits.device)
         targets = base + self.draft_id_to_target_id
         logits_new = logits.new_full((
@@ -246,4 +249,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                 name = "model." + name
             model_weights[name] = loaded_weight
 
-        return loader.load_weights(model_weights.items())
+        loaded_weights = loader.load_weights(model_weights.items())
+
+        if 'd2t' not in loaded_weights:
+            self.draft_id_to_target_id = None
+
+        return loaded_weights
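
To illustrate the behavior this change makes optional, here is a minimal pure-Python sketch of the remapping step. It is not the vLLM implementation. It assumes that each entry of the mapping is an offset added to the draft token ID (as `targets = base + self.draft_id_to_target_id` suggests) and that target positions with no draft counterpart are filled with negative infinity. The diff is cut off before the fill value, so that second point is an assumption. The names `remap_draft_logits` and `mapping_or_none` are hypothetical helpers introduced only for this example.

```python
from typing import Dict, List, Optional

NEG_INF = float("-inf")


def remap_draft_logits(
    draft_logits: List[float],
    draft_id_to_target_id: Optional[List[int]],
    target_vocab_size: int,
) -> List[float]:
    """Expand logits over the draft vocabulary to the target vocabulary."""
    # No mapping available: return the logits unchanged, which mirrors
    # the early return added to compute_logits in this commit.
    if draft_id_to_target_id is None:
        return draft_logits

    # Assumption: entry i is an offset, so draft id i maps to
    # target id i + offset. Unmapped target ids stay at -inf (assumed fill).
    expanded = [NEG_INF] * target_vocab_size
    for draft_id, offset in enumerate(draft_id_to_target_id):
        expanded[draft_id + offset] = draft_logits[draft_id]
    return expanded


def mapping_or_none(checkpoint: Dict[str, List[int]]) -> Optional[List[int]]:
    """Hypothetical helper: return the 'd2t' entry if present, else None."""
    return checkpoint.get("d2t")


if __name__ == "__main__":
    # Checkpoint without a 'd2t' entry: logits pass through untouched.
    print(remap_draft_logits([0.5, 1.5], mapping_or_none({}), 4))
    # Checkpoint with offsets [1, 2]: draft ids 0 and 1 land on target ids 1 and 3.
    print(remap_draft_logits([0.5, 1.5], mapping_or_none({"d2t": [1, 2]}), 4))
```

With a mapping, the output has one slot per target token. Without one, the draft logits are assumed to already index the target vocabulary, so no expansion is needed.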
