Commit 5757d90

[Speculative decoding] Adding configuration object for speculative decoding (#3706)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
1 parent: a3c226e

12 files changed: +394 -61 lines

tests/spec_decode/e2e/conftest.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import pytest
+
+from tests.conftest import cleanup
+from vllm import LLM
+from vllm.model_executor.utils import set_random_seed
+
+
+@pytest.fixture
+def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                           baseline_llm_kwargs, seed):
+    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                                baseline_llm_kwargs, seed)
+
+
+@pytest.fixture
+def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                       test_llm_kwargs, seed):
+    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                                test_llm_kwargs, seed)
+
+
+def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                         distinct_llm_kwargs, seed):
+    kwargs = {
+        **common_llm_kwargs,
+        **per_test_common_llm_kwargs,
+        **distinct_llm_kwargs,
+    }
+
+    def generator_inner():
+        llm = LLM(**kwargs)
+
+        set_random_seed(seed)
+
+        yield llm
+        del llm
+        cleanup()
+
+    for llm in generator_inner():
+        yield llm
+        del llm
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import pytest
+
+from vllm import SamplingParams
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+        "speculative_model": "facebook/opt-125m",
+        "num_speculative_tokens": 5,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_config(test_llm_generator):
+    output_len = 1024
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(
+            AssertionError,
+            match="Speculative decoding not yet supported for GPU backend"):
+        get_token_ids_from_llm_generator(test_llm_generator, prompts,
+                                         sampling_params)
+
+
+def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
+    for llm in llm_generator:
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        token_ids = [output.outputs[0].token_ids for output in outputs]
+        del llm
+
+    return token_ids
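
For orientation: the three parametrized dictionaries above are merged by create_llm_generator() from tests/spec_decode/e2e/conftest.py and forwarded to LLM(**kwargs). A minimal sketch of the resulting constructor call, not part of the commit; at this commit the GPU backend is expected to reject speculative decoding, which is exactly the AssertionError the test asserts:

from vllm import LLM

kwargs = {
    # common_llm_kwargs from the parametrize decorator above; the other two
    # parametrized dictionaries are empty in this test.
    "model": "facebook/opt-125m",
    "speculative_model": "facebook/opt-125m",
    "num_speculative_tokens": 5,
    "use_v2_block_manager": True,
}

# Expected to raise AssertionError:
# "Speculative decoding not yet supported for GPU backend"
llm = LLM(**kwargs)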

tests/spec_decode/utils.py

Lines changed: 8 additions & 10 deletions
@@ -107,18 +107,16 @@ def create_worker(cls: type,
         block_size=block_size,
         enforce_eager=enforce_eager,
     )
-
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
+    engine_config = engine_args.create_engine_config()

     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())

     worker = cls(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -128,9 +126,9 @@ def create_worker(cls: type,
     worker.init_device()
     worker.load_model()

-    cache_config.num_gpu_blocks = num_gpu_blocks
-    cache_config.num_cpu_blocks = 0
-    worker.init_cache_engine(cache_config)
+    engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
+    engine_config.cache_config.num_cpu_blocks = 0
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()

     return worker

tests/worker/test_swap.py

Lines changed: 8 additions & 9 deletions
@@ -10,19 +10,18 @@ def test_swap() -> None:
     engine_args = EngineArgs(model="facebook/opt-125m",
                              dtype="half",
                              load_format="dummy")
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
-    cache_config.num_gpu_blocks = 100
-    cache_config.num_cpu_blocks = 100
+    engine_config = engine_args.create_engine_config()
+    engine_config.cache_config.num_gpu_blocks = 100
+    engine_config.cache_config.num_cpu_blocks = 100

     # Create the worker.
     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())
     worker = Worker(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -32,7 +31,7 @@ def test_swap() -> None:
     # Initialize the worker.
     worker.init_device()
     worker.load_model()
-    worker.init_cache_engine(cache_config)
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()

     # Randomly initialize the cache.
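
Both test updates follow the same migration: EngineArgs.create_engine_configs(), which returned a tuple that callers unpacked positionally, is replaced by EngineArgs.create_engine_config(), which returns a single EngineConfig with named attributes. A minimal sketch of the new call pattern, mirroring the test above (the EngineArgs import path is assumed, as it is not shown in this excerpt):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         dtype="half",
                         load_format="dummy")

# One object with named attributes instead of a positional tuple.
engine_config = engine_args.create_engine_config()
engine_config.cache_config.num_gpu_blocks = 100
engine_config.cache_config.num_cpu_blocks = 100

model_config = engine_config.model_config
parallel_config = engine_config.parallel_config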

vllm/config.py

Lines changed: 187 additions & 1 deletion
@@ -1,7 +1,7 @@
 import enum
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import TYPE_CHECKING, ClassVar, Optional, Union

 import torch
@@ -617,6 +617,159 @@ def __init__(self, device: str = "auto") -> None:
         self.device = torch.device(self.device_type)


+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+
+    The configuration is currently specialized to draft-model speculative
+    decoding with top-1 proposals.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        target_dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if possible, else return None.
+
+        This function attempts to create a SpeculativeConfig object based on the
+        provided parameters. If the necessary conditions are met, it returns an
+        instance of SpeculativeConfig. Otherwise, it returns None.
+
+        Args:
+            target_model_config (ModelConfig): The configuration of the target
+                model.
+            target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            target_dtype (str): The data type used for the target model.
+            speculative_model (Optional[str]): The name of the speculative
+                model, if provided.
+            num_speculative_tokens (Optional[int]): The number of speculative
+                tokens, if provided.
+
+        Returns:
+            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
+                the necessary conditions are met, else None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None):
+            return None
+
+        if speculative_model is not None and num_speculative_tokens is None:
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+
+        # TODO: The user should be able to specify revision/quantization/max
+        # model len for the draft model. It is not currently supported.
+        draft_revision = None
+        draft_code_revision = None
+        draft_quantization = None
+        draft_max_model_len = None
+
+        draft_model_config = ModelConfig(
+            model=speculative_model,
+            tokenizer=target_model_config.tokenizer,
+            tokenizer_mode=target_model_config.tokenizer_mode,
+            trust_remote_code=target_model_config.trust_remote_code,
+            download_dir=target_model_config.download_dir,
+            load_format=target_model_config.load_format,
+            dtype=target_model_config.dtype,
+            seed=target_model_config.seed,
+            revision=draft_revision,
+            code_revision=draft_code_revision,
+            tokenizer_revision=target_model_config.tokenizer_revision,
+            max_model_len=draft_max_model_len,
+            quantization=draft_quantization,
+            enforce_eager=target_model_config.enforce_eager,
+            max_context_len_to_capture=target_model_config.
+            max_context_len_to_capture,
+            max_logprobs=target_model_config.max_logprobs,
+        )
+
+        draft_parallel_config = (
+            SpeculativeConfig.create_draft_parallel_config(
+                target_parallel_config))
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+
+        This is mostly a copy of the target parallel config. In the future the
+        draft worker can have a different parallel strategy, e.g. TP=1.
+        """
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            max_parallel_loading_workers=target_parallel_config.
+            max_parallel_loading_workers,
+            disable_custom_all_reduce=target_parallel_config.
+            disable_custom_all_reduce,
+            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+            placement_group=target_parallel_config.placement_group,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+    ):
+        """Create a SpeculativeConfig object.
+
+        Args:
+            draft_model_config: ModelConfig for the draft model.
+            draft_parallel_config: ParallelConfig for the draft model.
+            num_speculative_tokens: The number of tokens to sample from the
+                draft model before scoring with the target model.
+        """
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens <= 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             f"than zero ({self.num_speculative_tokens}).")
+
+        if self.draft_model_config:
+            self.draft_model_config.verify_with_parallel_config(
+                self.draft_parallel_config)
+
+    @property
+    def num_lookahead_slots(self) -> int:
+        """The number of additional slots the scheduler should allocate per
+        step, in addition to the slots allocated for each known token.

+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+    def __repr__(self) -> str:
+        draft_model = self.draft_model_config.model
+        num_spec_tokens = self.num_speculative_tokens
+        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int
@@ -838,3 +991,36 @@ def _get_and_verify_max_len(
             "to incorrect model outputs or CUDA errors. Make sure the "
             "value is correct and within the model context size.")
     return int(max_model_len)
+
+
+@dataclass(frozen=True)
+class EngineConfig:
+    """Dataclass which contains all engine-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    model_config: ModelConfig
+    cache_config: CacheConfig
+    parallel_config: ParallelConfig
+    scheduler_config: SchedulerConfig
+    device_config: DeviceConfig
+    lora_config: Optional[LoRAConfig]
+    vision_language_config: Optional[VisionLanguageConfig]
+    speculative_config: Optional[SpeculativeConfig]
+
+    def __post_init__(self):
+        """Verify configs are valid & consistent with each other.
+        """
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+
+    def to_dict(self):
+        """Return the configs as a dictionary, for use in **kwargs.
+        """
+        return dict(
+            (field.name, getattr(self, field.name)) for field in fields(self))
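
Two notes on the new config classes. SpeculativeConfig.num_lookahead_slots simply forwards num_speculative_tokens: the scheduler must reserve one extra slot per proposed token so the target model can score it. EngineConfig.to_dict() exists so the frozen dataclass can be splatted into constructors that take the individual configs as keyword arguments. A self-contained illustration of that fields()-based pattern, with toy field types standing in for the real config classes (build_engine is hypothetical):

from dataclasses import dataclass, fields
from typing import Optional


@dataclass(frozen=True)
class ToyEngineConfig:
    model_config: str
    speculative_config: Optional[str]

    def to_dict(self):
        # Same construction as EngineConfig.to_dict() in the diff above.
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))


def build_engine(*, model_config, speculative_config):
    # Hypothetical consumer that accepts each config as a keyword argument.
    return {"model": model_config, "spec": speculative_config}


config = ToyEngineConfig(model_config="facebook/opt-125m",
                         speculative_config=None)
engine = build_engine(**config.to_dict())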
