[Dynamic Spec Decoding] Minor fix for disabling speculative decoding (v…
LiuXiaoxuanPKU authored and joerunde committed Jun 3, 2024
1 parent 91fadd7 commit cf06cc8
Showing 3 changed files with 63 additions and 11 deletions.
41 changes: 41 additions & 0 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
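
For reference, a minimal usage sketch (not part of this commit): it assumes the kwargs exercised by the test above map one-to-one onto vllm.LLM constructor arguments, so that n-gram speculation is switched off automatically once the running batch reaches the configured size.

from vllm import LLM, SamplingParams

# Hedged sketch, not from the commit: construct an engine with the same
# kwargs the test passes, so speculation is auto-disabled at batch size 4.
llm = LLM(
    model="JackFram/llama-68m",
    enforce_eager=True,                   # skip CUDA graph capture, as in the test
    use_v2_block_manager=True,            # required for spec decode
    speculative_model="[ngram]",          # prompt-lookup (n-gram) proposer
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    speculative_disable_by_batch_size=4,  # auto-disable threshold under test
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
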
16 changes: 10 additions & 6 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import torch
@@ -13,9 +13,9 @@
from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
num_lookahead_slots=k,
running_queue_size=queue_size)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
if queue_size > disable_by_batch_size:
with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)

# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
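
The verification pattern used in the updated test can be shown in isolation. Below is a self-contained sketch (the Worker class is a hypothetical stand-in, not vLLM's SpecDecodeWorker): an internal method is patched to raise a sentinel ValueError, and pytest.raises confirms that the exception propagates, proving that execute_model actually took the non-speculative path.

from unittest.mock import patch

import pytest


class Worker:
    """Hypothetical stand-in used only to illustrate the test pattern."""

    def __init__(self, disable_by_batch_size: int):
        self.disable_by_batch_size = disable_by_batch_size

    def execute_model(self, batch_size: int):
        # Mirror the gating behavior under test: fall back to the
        # non-speculative path once the batch reaches the threshold.
        if batch_size >= self.disable_by_batch_size:
            return self._run_no_spec()
        return self._run_speculative()

    def _run_no_spec(self):
        return "no-spec"

    def _run_speculative(self):
        return "spec"


def test_no_spec_path_taken():
    sentinel = "artificial stop"
    worker = Worker(disable_by_batch_size=4)
    # Patch the internal method so reaching it raises a recognizable error,
    # then assert that the error surfaces, i.e. the path was actually taken.
    with patch.object(worker, "_run_no_spec",
                      side_effect=ValueError(sentinel)), \
            pytest.raises(ValueError, match=sentinel):
        worker.execute_model(batch_size=4)
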
17 changes: 12 additions & 5 deletions vllm/spec_decode/spec_decode_worker.py
@@ -273,10 +273,17 @@ def execute_model(
self._maybe_disable_speculative_tokens(
disable_all_speculation, execute_model_req.seq_group_metadata_list)

# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
execute_model_req.seq_group_metadata_list
) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)

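The updated condition above can be read as a single predicate; a hedged restatement follows (the function name is illustrative only, not a vLLM helper).

# Illustrative only: restates the three cases listed in the comment above.
def should_skip_speculation(num_lookahead_slots: int,
                            num_seq_groups: int,
                            disable_all_speculation: bool) -> bool:
    # Prefill step, empty batch, or speculation auto-disabled for this step.
    return (num_lookahead_slots == 0
            or num_seq_groups == 0
            or disable_all_speculation)
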
@@ -316,8 +323,8 @@ def _maybe_disable_speculative_tokens(
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to
the proposer and scorer model so that the KV cache is consistent
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the proposer's KV cache is not updated for
these requests, so speculative decoding cannot be enabled for them in
later decoding steps.
