[WIP] Speculative Decoding #1797

Closed. Wants to merge 47 commits.

Commits (47)
75ae5fd  spec draft (LiuXiaoxuanPKU, Nov 6, 2023)
46cd4c3  Merge branch 'vllm-project:main' into spec (LiuXiaoxuanPKU, Nov 6, 2023)
edeaec0  minor (LiuXiaoxuanPKU, Nov 6, 2023)
95a7e13  minor (LiuXiaoxuanPKU, Nov 8, 2023)
366fbb9  draft tokens (LiuXiaoxuanPKU, Nov 8, 2023)
3c7397e  minor (LiuXiaoxuanPKU, Nov 8, 2023)
9f35009  merge (LiuXiaoxuanPKU, Nov 8, 2023)
9b64276  Merge branch 'main' of github.com:LiuXiaoxuanPKU/vllm (LiuXiaoxuanPKU, Nov 8, 2023)
1525262  Merge branch 'main' into spec (LiuXiaoxuanPKU, Nov 8, 2023)
7e6224a  minor (LiuXiaoxuanPKU, Nov 9, 2023)
93901c8  Merge branch 'spec' of github.com:LiuXiaoxuanPKU/vllm into spec (LiuXiaoxuanPKU, Nov 9, 2023)
692328a  draft logits (LiuXiaoxuanPKU, Nov 9, 2023)
8b6d647  need to change draft token probs data structure (LiuXiaoxuanPKU, Nov 9, 2023)
675e1ae  rejection sampling (LiuXiaoxuanPKU, Nov 9, 2023)
32267f6  rejection sampling (LiuXiaoxuanPKU, Nov 10, 2023)
1aab040  format (LiuXiaoxuanPKU, Nov 12, 2023)
826b54a  get draft probs (LiuXiaoxuanPKU, Nov 12, 2023)
b2ec9aa  style (LiuXiaoxuanPKU, Nov 12, 2023)
6382396  combine draft_token_ids and output_token_ids in SequenceData (LiuXiaoxuanPKU, Nov 13, 2023)
89d8ba2  invalidate kv draft (LiuXiaoxuanPKU, Nov 13, 2023)
9594d08  fix (LiuXiaoxuanPKU, Nov 13, 2023)
6b1e94c  pass in multiple tokens for generation phase, kv_mqa (LiuXiaoxuanPKU, Nov 13, 2023)
2d5c379  pass scheduler to spec worker (LiuXiaoxuanPKU, Nov 13, 2023)
025bb89  mqa (LiuXiaoxuanPKU, Nov 15, 2023)
dd23ff7  separate sampler (LiuXiaoxuanPKU, Nov 15, 2023)
f1b3987  lots of fix, multi_qa_kv runnable (LiuXiaoxuanPKU, Nov 16, 2023)
9a85990  nan in hidden states (LiuXiaoxuanPKU, Nov 16, 2023)
54bfebd  lots of style fix, early break accepting tokens (LiuXiaoxuanPKU, Nov 17, 2023)
a904ac9  fix free bug (LiuXiaoxuanPKU, Nov 18, 2023)
0cb9326  bug fix (LiuXiaoxuanPKU, Nov 18, 2023)
4e9ae6c  minor fix get target probs in prefill phase (LiuXiaoxuanPKU, Nov 18, 2023)
0ff36e7  fix mismatch between logical and physical blocks!! (LiuXiaoxuanPKU, Nov 24, 2023)
d2d67f9  add alphas (LiuXiaoxuanPKU, Nov 27, 2023)
7d94cb2  tokenizer & bug fix (LiuXiaoxuanPKU, Nov 30, 2023)
b1a5a88  pass tests (LiuXiaoxuanPKU, Nov 30, 2023)
93c7956  add flag (LiuXiaoxuanPKU, Dec 3, 2023)
141da66  remove speculative decoding for prompt run (LiuXiaoxuanPKU, Dec 5, 2023)
439c88b  remove temperature, only support all greedy for now (LiuXiaoxuanPKU, Dec 6, 2023)
40ab8d4  clean (Dec 7, 2023)
bf2ebe9  minor (Dec 7, 2023)
179e968  merge (LiuXiaoxuanPKU, Dec 7, 2023)
664a256  fix & pass tests (LiuXiaoxuanPKU, Dec 7, 2023)
7f9a373  format (LiuXiaoxuanPKU, Dec 7, 2023)
0540142  remove old files (LiuXiaoxuanPKU, Dec 7, 2023)
993f2d4  remove untouched file (LiuXiaoxuanPKU, Dec 8, 2023)
c410cbe  format (LiuXiaoxuanPKU, Dec 8, 2023)
9f2d98b  format (LiuXiaoxuanPKU, Dec 8, 2023)
Files changed
2 changes: 1 addition & 1 deletion benchmarks/benchmark_latency.py
@@ -124,4 +124,4 @@ def run_to_completion(profile_dir: Optional[str] = None):
'with ui.perfetto.dev or Tensorboard.'
))
args = parser.parse_args()
main(args)
main(args)
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -152,13 +152,17 @@ def __init__(
model_name: str,
tokenizer_name: Optional[str] = None,
dtype: str = "half",
draft_model: str = None,
propose_cnt: int = 1,
) -> None:
self.model = LLM(
model=model_name,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,
swap_space=0,
draft_model=draft_model,
propose_cnt=propose_cnt,
)

def generate(
50 changes: 50 additions & 0 deletions tests/models/test_spec_dec.py
@@ -0,0 +1,50 @@
"""Compare the outputs of Specutiave Decoding and original vLLM

Run `pytest tests/models/test_spec_dec.py --forked`.
"""
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
from vllm.config import FLAGS
import pytest

MODELS = [
"lmsys/vicuna-7b-v1.3",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [50])
@pytest.mark.parametrize("draft_model", ["JackFram/llama-160m"])
@pytest.mark.parametrize("propose_cnt", [5])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
draft_model: str,
propose_cnt: int,
) -> None:
spec_vllm_model = vllm_runner(model,
dtype=dtype,
draft_model=draft_model,
propose_cnt=propose_cnt)
spec_vllm_outputs = spec_vllm_model.generate_greedy(
example_prompts, max_tokens)
del spec_vllm_model
destroy_model_parallel()

FLAGS.ENABLE_SD = False
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

for i in range(len(example_prompts)):
spec_output_ids, spec_output_str = spec_vllm_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert spec_output_str == vllm_output_str, (
f"Test{i}:\nSpec: {len(spec_output_str)}\nvLLM: {len(vllm_output_str)}"
)
assert spec_output_ids == vllm_output_ids, (
f"Test{i}:\nSpec: {len(spec_output_ids)}\nvLLM: {len(vllm_output_ids)}"
)
8 changes: 8 additions & 0 deletions vllm/block.py
@@ -46,6 +46,14 @@ def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]

# Delete `num` tokens from the end of this block.
def delete_last_tokens(self, num: int) -> None:
assert num > 0
assert num <= self.num_tokens
self.num_tokens -= num
for i in range(self.num_tokens, len(self.token_ids)):
self.token_ids[i] = _BLANK_TOKEN_ID


class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
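Below is a standalone sketch of how the new delete_last_tokens helper behaves (illustration only, not vLLM's actual LogicalTokenBlock; the _BLANK_TOKEN_ID sentinel is assumed to be -1): rolling back rejected draft tokens simply rewinds num_tokens and blanks the freed slots in place.

# Simplified stand-in class for illustration; mirrors the delete_last_tokens
# logic added above, assuming _BLANK_TOKEN_ID == -1 and a fixed block size.
_BLANK_TOKEN_ID = -1


class MiniLogicalBlock:

    def __init__(self, block_size: int) -> None:
        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def append_tokens(self, token_ids: list) -> None:
        for token_id in token_ids:
            self.token_ids[self.num_tokens] = token_id
            self.num_tokens += 1

    def delete_last_tokens(self, num: int) -> None:
        assert 0 < num <= self.num_tokens
        self.num_tokens -= num
        for i in range(self.num_tokens, len(self.token_ids)):
            self.token_ids[i] = _BLANK_TOKEN_ID


block = MiniLogicalBlock(block_size=4)
block.append_tokens([11, 22, 33])
block.delete_last_tokens(2)  # roll back two rejected draft tokens
assert block.num_tokens == 1
assert block.token_ids == [11, -1, -1, -1]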
12 changes: 12 additions & 0 deletions vllm/config.py
@@ -13,6 +13,10 @@
_GB = 1 << 30


class FLAGS:
ENABLE_SD = False


class ModelConfig:
"""Configuration for the model.

@@ -356,6 +360,14 @@ def _verify_args(self) -> None:
f"({self.max_num_seqs}).")


class SpecDecConfig:

def __init__(self, draft_model_config: ModelConfig,
propose_cnt: int) -> None:
self.draft_model_config = draft_model_config
self.propose_cnt = propose_cnt


_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
9 changes: 9 additions & 0 deletions vllm/core/block_manager.py
@@ -179,6 +179,15 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number

def free_tailing_blocks(self, seq: Sequence) -> None:
block_table = self.block_tables[seq.seq_id]
free_cnt = len(seq.logical_token_blocks) - len(block_table)
while free_cnt > 0:
block = block_table.pop()
self.gpu_allocator.free(block)
free_cnt -= 1
self.block_tables[seq.seq_id] = block_table

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
24 changes: 23 additions & 1 deletion vllm/core/scheduler.py
@@ -7,7 +7,8 @@
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
SequenceGroupMetadata, SequenceStatus,
SequenceOutput)

logger = init_logger(__name__)

@@ -309,6 +310,27 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)

def free_invalid_kv(self, seq: Sequence, seq_out: SequenceOutput):
# if all the tokens are accepted
# draft_token_ids: [A, B, C], accepted_tokens: [A, B, C, D], invalid_token_cnt = 3 + 1 - 4 = 0
# if part of the tokens are accepted
# draft_token_ids: [A, B, C], accepted_tokens: [A, B, D], invalid_token_cnt = 3 + 1 - 3 = 1
invalid_token_cnt = len(seq.data.get_draft_token_ids()) + 1 - len(
seq_out.accepted_tokens)
assert invalid_token_cnt >= 0

if invalid_token_cnt == 0:
return invalid_token_cnt

# delete data
seq.data.output_token_ids = seq.data.output_token_ids[:
-invalid_token_cnt]
# delete from logical table
seq.delete_tailing_tokens(invalid_token_cnt)
# delete from physical table
self.block_manager.free_tailing_blocks(seq)
return invalid_token_cnt

def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
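A minimal standalone sketch of the arithmetic free_invalid_kv performs, following the two examples in the comments above. The helper name here is hypothetical; SequenceOutput, the logical table, and the block manager are omitted, only the count is shown.

# Hypothetical helper for illustration; not part of the PR.
from typing import List


def invalid_token_count(draft_token_ids: List[str],
                        accepted_tokens: List[str]) -> int:
    # The target model verifies every draft token plus one bonus token,
    # so at most len(draft_token_ids) + 1 tokens can be accepted per step.
    cnt = len(draft_token_ids) + 1 - len(accepted_tokens)
    assert cnt >= 0
    return cnt


# All drafts accepted (plus the bonus token): nothing to invalidate.
assert invalid_token_count(["A", "B", "C"], ["A", "B", "C", "D"]) == 0
# Third draft rejected and replaced by "D": one stale KV entry to free.
assert invalid_token_count(["A", "B", "C"], ["A", "B", "D"]) == 1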
37 changes: 34 additions & 3 deletions vllm/engine/arg_utils.py
@@ -4,7 +4,7 @@
from typing import Optional, Tuple

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, SpecDecConfig)


@dataclass
@@ -25,7 +25,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
gpu_memory_utilization: float = 0.80
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
@@ -34,6 +34,9 @@
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None

draft_model: Optional[str] = None
propose_cnt: Optional[int] = None

def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@@ -182,6 +185,21 @@ def add_cli_args(
choices=['awq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')

# speculative decoding setting
parser.add_argument(
'--draft-model',
type=str,
default=None,
help=
'name or path of the huggingface model to use as the draft model')
parser.add_argument(
'--propose-cnt',
type=int,
default=5,
help=
'for speculative decoding, number of tokens to propose each step')

return parser

@classmethod
@@ -213,7 +231,20 @@ def create_engine_configs(
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config

spec_dec_config: SpecDecConfig = None
if self.draft_model:
# assume the draft model and target model share the same tokenizer
# for now, share the same seed as the target
draft_model_config = ModelConfig(self.draft_model, self.tokenizer,
self.tokenizer_mode,
self.trust_remote_code,
self.download_dir,
self.load_format, 'auto',
self.seed)
spec_dec_config = SpecDecConfig(draft_model_config,
self.propose_cnt)
return model_config, cache_config, parallel_config, scheduler_config, spec_dec_config


@dataclass
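Based on the tests/conftest.py change above, the feature can presumably be driven from the Python API roughly as follows. This is a sketch only: the draft_model and propose_cnt keyword arguments are assumed to be plumbed from LLM into these EngineArgs on this branch, and greedy sampling is used because the "remove temperature, only support all greedy for now" commit restricts this branch to greedy decoding.

# Sketch based on the conftest.py diff above; draft_model/propose_cnt as
# LLM(...) kwargs are assumed to reach EngineArgs on this branch.
from vllm import LLM, SamplingParams

llm = LLM(
    model="lmsys/vicuna-7b-v1.3",       # target model (from the test)
    draft_model="JackFram/llama-160m",  # small draft model
    propose_cnt=5,                      # draft tokens proposed per step
    dtype="half",
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=50)  # greedy only
outputs = llm.generate(["The future of AI is"], sampling_params)
print(outputs[0].outputs[0].text)

The corresponding command-line flags added in add_cli_args above are --draft-model and --propose-cnt.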