Commit 1f496a1

Fix draft logprobs zeros bug and add acceptance sanity checks
Bug #1: EAGLE tree proposal returned zeros for draft_logprobs
- Root cause: when using top-k for tree branching, the code set draft_logp_list=None, then created a zeros tensor as a fallback (lines 850-851)
- Fix: compute actual log-probs from the logits using log_softmax + gather
- Applied at 2 locations: root level (lines 698-704) and tree levels (lines 839-846)

Bug #2: added diagnostic logging in the rejection sampler
- Log draft_p (nonzero) min/med/max to detect zeros
- Log p_target min/med/max to detect a degenerate softmax
- Helps identify whether target logits are masked/filtered before sampling

Expected results after the fix:
- draft_logp: -3.2/-1.6/-0.0 (real log-probs, all ≤ 0) instead of 0/0/0
- p_target: 1e-6/1e-3/0.7 (a realistic distribution) instead of 1/1/1
- Acceptance rate: 30-70% instead of 0%

Files changed:
- vllm/v1/spec_decode/eagle.py: fix the draft_logp computation
- vllm/v1/sample/rejection_sampler.py: add sanity logging
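
As context for the diff below, here is a minimal standalone sketch of the log_softmax + gather pattern the Bug #1 fix relies on. Shapes and variable names are illustrative, not copied from eagle.py:

    import torch

    # Illustrative shapes only; in eagle.py the logits come from the draft model.
    batch_size, vocab_size, num_children = 4, 32000, 2
    logits = torch.randn(batch_size, vocab_size)

    # Deterministic top-k branching picks the child tokens...
    draft_token_ids = torch.topk(logits, num_children, dim=-1).indices  # (batch, k)

    # ...and their real log-probabilities come from normalizing the logits with
    # log_softmax and gathering the entries at the chosen token ids.
    log_probs = torch.log_softmax(logits, dim=-1)                       # (batch, vocab)
    draft_logp = log_probs.gather(-1, draft_token_ids)                  # (batch, k)

    assert (draft_logp <= 0).all()  # valid log-probs are never positive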
1 parent 16cdf4f commit 1f496a1

2 files changed: 29 additions, 7 deletions

vllm/v1/sample/rejection_sampler.py

Lines changed: 20 additions & 0 deletions
@@ -124,6 +124,26 @@ def forward(
             sampling_metadata,
         )
 
+        # Sanity checks: Inspect raw values BEFORE any processing
+        if draft_probs is not None:
+            draft_p_nonzero = draft_probs[draft_probs > 0]
+            if draft_p_nonzero.numel() > 0:
+                print(f"[SANITY] draft_p (nonzero) min/med/max: "
+                      f"{draft_p_nonzero.min():.3e}/"
+                      f"{draft_p_nonzero.median():.3e}/"
+                      f"{draft_p_nonzero.max():.3e}",
+                      file=sys.stderr, flush=True)
+            else:
+                print(f"[SANITY] draft_p: ALL ZEROS!", file=sys.stderr, flush=True)
+
+        # Check target probabilities for the chosen draft tokens
+        target_p_check = target_probs.gather(-1, metadata.draft_token_ids.unsqueeze(-1)).squeeze(-1)
+        print(f"[SANITY] p_target min/med/max: "
+              f"{target_p_check.min():.3e}/"
+              f"{target_p_check.median():.3e}/"
+              f"{target_p_check.max():.3e}",
+              file=sys.stderr, flush=True)
+
         output_token_ids = rejection_sample(
             metadata.draft_token_ids,
             metadata.num_draft_tokens,
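
The values logged above feed the speculative-decoding accept/reject step (note the added prints assume sys is already imported in this module). As a rough sketch of why they matter, the textbook acceptance rule (not vLLM's actual rejection_sample kernel) accepts a drafted token with probability min(1, p_target / p_draft), so zeroed draft probabilities or a degenerate target distribution distort the acceptance rate directly. The helper below is purely illustrative:

    import torch

    def accept_prob(p_target: torch.Tensor, p_draft: torch.Tensor) -> torch.Tensor:
        # Textbook speculative-decoding rule: accept a drafted token with
        # probability min(1, p_target / p_draft); clamp p_draft to avoid a
        # division by zero when the draft distribution is degenerate.
        return torch.clamp(p_target / p_draft.clamp_min(1e-10), max=1.0)

    # Realistic values (cf. the expected p_target range in the commit message)
    # yield a mid-range acceptance rate rather than 0% or 100%.
    p_target = torch.tensor([0.7, 1e-3, 1e-6])
    p_draft = torch.tensor([0.8, 0.05, 0.2])
    print(accept_prob(p_target, p_draft))  # ~[0.875, 0.02, 5e-06]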

vllm/v1/spec_decode/eagle.py

Lines changed: 9 additions & 7 deletions
@@ -695,10 +695,13 @@ def propose_tree(
             draft_token_ids_list = [draft_token_ids]
             draft_logp_list = [draft_logp]
         else:
+            # Top-k branching: compute actual log-probs for chosen tokens
             draft_token_ids = torch.topk(logits, num_children,
                                          dim=-1).indices.view(batch_size, -1)
+            log_probs = torch.log_softmax(logits, dim=-1)
+            draft_logp = log_probs.gather(-1, draft_token_ids).view(batch_size, -1)
             draft_token_ids_list = [draft_token_ids]
-            draft_logp_list = None  # No valid draft probs for deterministic top-k
+            draft_logp_list = [draft_logp]
         draft_hidden_states = hidden_states.view(batch_size, 1, -1)
 
         # Initialize empty tensors for concatenation with the level outputs.

@@ -833,23 +836,22 @@
                 draft_token_ids_list.append(draft_token_ids)
                 draft_logp_list.append(draft_logp)
             else:
+                # Top-k branching: compute actual log-probs for chosen tokens
                 draft_token_ids = torch.topk(logits, num_children,
                                              dim=-1).indices.view(
                                                  batch_size, -1)
+                log_probs = torch.log_softmax(logits, dim=-1)
+                draft_logp = log_probs.gather(-1, draft_token_ids).view(batch_size, -1)
                 draft_token_ids_list.append(draft_token_ids)
-                if draft_logp_list is not None:
-                    draft_logp_list = None  # Mixed modes -> no valid logprobs
+                draft_logp_list.append(draft_logp)
 
             # Update the # drafts counters for the next tree level.
             level_num_drafts = self.cu_drafts_per_level[level +
                                                         1] - total_num_drafts
             total_num_drafts = self.cu_drafts_per_level[level + 1]
 
         # Return both token IDs and logprobs
-        # If draft_logp_list is None (top-k was used), return zeros
-        if draft_logp_list is None:
-            draft_logp_list = [torch.zeros_like(draft_token_ids_list[0], dtype=torch.float32)]
-
+        # All branches now compute real logprobs, no need for fallback
         return draft_token_ids_list, draft_logp_list
 
     def prepare_inputs(
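
A quick way to see why the old zeros fallback broke the acceptance math: a log-prob of 0 corresponds to exp(0) = 1, i.e. the sampler was told every drafted token had probability 1 under the draft model. The snippet below mirrors the "Expected results" in the commit message with illustrative values:

    import torch

    real_logp = torch.tensor([-3.2, -1.6, -0.0])  # the kind of values the fix should produce
    zeros_fallback = torch.zeros(3)               # what the old fallback returned

    assert (real_logp <= 0).all()
    print(real_logp.exp())       # tensor([0.0408, 0.2019, 1.0000]) -- genuine probabilities
    print(zeros_fallback.exp())  # tensor([1., 1., 1.]) -- every draft token looks "certain"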
