Skip to content

Commit

Permalink
[Speculative decoding 6/9] Integrate speculative decoding with LLMEng…
Browse files Browse the repository at this point in the history
…ine (#3894)
  • Loading branch information
cadedaniel authored Apr 16, 2024
1 parent 69e1d2f commit e95cd87
Show file tree
Hide file tree
Showing 31 changed files with 1,347 additions and 407 deletions.
70 changes: 70 additions & 0 deletions tests/core/block/e2e/test_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
"common_llm_kwargs",
[
{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 2,
"max_num_seqs": 2,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
{
"use_v2_block_manager": False,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"use_v2_block_manager": True,
"num_lookahead_slots": 0,
},
{
"use_v2_block_manager": True,
"num_lookahead_slots": 5,
},
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify that chunked prefill works with BlockManagerV2, with and without
lookahead scheduling.
"""
output_len = 32
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

print('Getting token ids with BlockManagerV1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids with BlockManagerV2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
Expand Down
17 changes: 10 additions & 7 deletions tests/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple

from vllm import SamplingParams
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -31,14 +31,17 @@ def create_dummy_prompt(


def create_seq_group(
seq_prompt_len=1024,
seq_output_lens=(128, ),
request_id='0',
seq_id_start=0,
) -> SequenceGroup:
seq_prompt_len: int = 1024,
seq_output_lens: Iterable[int] = (128, ),
request_id: str = '0',
seq_id_start: int = 0,
sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:

assert len(seq_output_lens) > 0

if sampling_params is None:
sampling_params = SamplingParams()

prompt_token_ids = [0] * seq_prompt_len

seqs = []
Expand All @@ -60,7 +63,7 @@ def create_seq_group(
seq_group = SequenceGroup(
request_id=request_id,
seqs=seqs,
sampling_params=SamplingParams(),
sampling_params=sampling_params,
arrival_time=time.time(),
)

Expand Down
270 changes: 270 additions & 0 deletions tests/engine/output_processor/test_multi_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import random
from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
"""Verify multi-step decoding appends token ids correctly.
We append token ids and verify all the token ids were appended correctly.
Note that ignore_eos=True.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=1024,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=seq_output_len +
num_new_tokens,
ignore_eos=True),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
output_processor.process_outputs(seq_group, outputs)
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, max_tokens: int):
"""Verify tokens after max_tokens are dropped and not appended to the
sequence.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=max_tokens, ),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to not go over max tokens in len.
assert seq.get_len() == seq_prompt_len + max_tokens

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""Verify the eos token id is included in the sequence, but subsequent
tokens are dropped (not appended to sequence).
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

eos_token_id = 100

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens, ),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to not go beyond provided eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:eos_index + 1]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""When sampling parameters dictate that we should ignore the eos token id,
ensure all token ids are appended even if the eos token id is emitted.
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

eos_token_id = 100

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens,
ignore_eos=True,
),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to go beyond eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


def mock_tokenizer(eos_token_id=1000):
tokenizer = MagicMock(spec=PreTrainedTokenizer)
tokenizer.eos_token_id = eos_token_id
return tokenizer
Loading

0 comments on commit e95cd87

Please sign in to comment.