Commit 7bc7532: spec decode config
1 parent 93deb0b

12 files changed, +393 -60 lines

tests/spec_decode/e2e/conftest.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                test_llm_kwargs, seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)

        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    for llm in generator_inner():
        yield llm
        del llm
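
Note: create_llm_generator merges its argument groups with plain dict unpacking, so later groups override earlier ones (test-specific kwargs beat per-test common kwargs, which beat the shared common kwargs). A minimal standalone sketch of that precedence; the kwarg values below are illustrative, not taken from this commit:

common_llm_kwargs = {"model": "facebook/opt-125m", "enforce_eager": False}
per_test_common_llm_kwargs = {"enforce_eager": True}
test_llm_kwargs = {"num_speculative_tokens": 5}

# Same merge order as create_llm_generator: the last unpacked dict wins.
kwargs = {
    **common_llm_kwargs,
    **per_test_common_llm_kwargs,
    **test_llm_kwargs,
}
assert kwargs == {
    "model": "facebook/opt-125m",
    "enforce_eager": True,
    "num_speculative_tokens": 5,
}
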
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    output_len = 1024
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
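
For context, a hedged sketch of what the parametrized kwargs above amount to when constructing an LLM directly. At this commit the GPU backend is expected to reject the configuration with the AssertionError matched by the test, so this illustrates the intended API surface rather than a working end-to-end run:

from vllm import LLM, SamplingParams

# Same kwargs as the "common_llm_kwargs" parametrization above.
llm = LLM(
    model="facebook/opt-125m",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16, temperature=0.0))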

tests/spec_decode/utils.py

Lines changed: 8 additions & 10 deletions
@@ -107,18 +107,16 @@ def create_worker(cls: type,
                              block_size=block_size,
                              enforce_eager=enforce_eager,
                              )
-
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
+    engine_config = engine_args.create_engine_config()
 
     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())
 
     worker = cls(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -128,9 +126,9 @@ def create_worker(cls: type,
     worker.init_device()
     worker.load_model()
 
-    cache_config.num_gpu_blocks = num_gpu_blocks
-    cache_config.num_cpu_blocks = 0
-    worker.init_cache_engine(cache_config)
+    engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
+    engine_config.cache_config.num_cpu_blocks = 0
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()
 
     return worker
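
This hunk (and the matching one in tests/worker/test_swap.py below) replaces the old tuple unpacking of create_engine_configs() with a single EngineConfig object. A hedged sketch of the caller-side pattern, assuming the usual EngineArgs import path:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m")
engine_config = engine_args.create_engine_config()

# Sub-configs are read as named attributes instead of positional tuple slots,
# and tests can still override fields such as the block counts before use.
engine_config.cache_config.num_gpu_blocks = 128
engine_config.cache_config.num_cpu_blocks = 0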

tests/worker/test_swap.py

Lines changed: 8 additions & 9 deletions
@@ -10,19 +10,18 @@ def test_swap() -> None:
     engine_args = EngineArgs(model="facebook/opt-125m",
                              dtype="half",
                              load_format="dummy")
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
-    cache_config.num_gpu_blocks = 100
-    cache_config.num_cpu_blocks = 100
+    engine_config = engine_args.create_engine_config()
+    engine_config.cache_config.num_gpu_blocks = 100
+    engine_config.cache_config.num_cpu_blocks = 100
 
     # Create the worker.
     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())
     worker = Worker(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -32,7 +31,7 @@ def test_swap() -> None:
     # Initialize the worker.
     worker.init_device()
     worker.load_model()
-    worker.init_cache_engine(cache_config)
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()
 
     # Randomly initialize the cache.

vllm/config.py

Lines changed: 187 additions & 1 deletion
@@ -1,7 +1,7 @@
 import enum
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional, Union
 
 import torch
@@ -612,6 +612,159 @@ def __init__(self, device: str = "auto") -> None:
         self.device = torch.device(self.device_type)
 
 
+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+
+    The configuration is currently specialized to draft-model speculative
+    decoding with top-1 proposals.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        target_dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if possible, else return None.
+
+        This function attempts to create a SpeculativeConfig object based on the
+        provided parameters. If the necessary conditions are met, it returns an
+        instance of SpeculativeConfig. Otherwise, it returns None.
+
+        Args:
+            target_model_config (ModelConfig): The configuration of the target
+                model.
+            target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            target_dtype (str): The data type used for the target model.
+            speculative_model (Optional[str]): The name of the speculative
+                model, if provided.
+            num_speculative_tokens (Optional[int]): The number of speculative
+                tokens, if provided.
+
+        Returns:
+            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
+                the necessary conditions are met, else None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None):
+            return None
+
+        if speculative_model is not None and num_speculative_tokens is None:
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+
+        # TODO: The user should be able to specify revision/quantization/max
+        # model len for the draft model. It is not currently supported.
+        draft_revision = None
+        draft_code_revision = None
+        draft_quantization = None
+        draft_max_model_len = None
+
+        draft_model_config = ModelConfig(
+            model=speculative_model,
+            tokenizer=target_model_config.tokenizer,
+            tokenizer_mode=target_model_config.tokenizer_mode,
+            trust_remote_code=target_model_config.trust_remote_code,
+            download_dir=target_model_config.download_dir,
+            load_format=target_model_config.load_format,
+            dtype=target_model_config.dtype,
+            seed=target_model_config.seed,
+            revision=draft_revision,
+            code_revision=draft_code_revision,
+            tokenizer_revision=target_model_config.tokenizer_revision,
+            max_model_len=draft_max_model_len,
+            quantization=draft_quantization,
+            enforce_eager=target_model_config.enforce_eager,
+            max_context_len_to_capture=target_model_config.
+            max_context_len_to_capture,
+            max_logprobs=target_model_config.max_logprobs,
+        )
+
+        draft_parallel_config = (
+            SpeculativeConfig.create_draft_parallel_config(
+                target_parallel_config))
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+
+        This is mostly a copy of the target parallel config. In the future the
+        draft worker can have a different parallel strategy, e.g. TP=1.
+        """
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            max_parallel_loading_workers=target_parallel_config.
+            max_parallel_loading_workers,
+            disable_custom_all_reduce=target_parallel_config.
+            disable_custom_all_reduce,
+            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+            placement_group=target_parallel_config.placement_group,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+    ):
+        """Create a SpeculativeConfig object.
+
+        Args:
+            draft_model_config: ModelConfig for the draft model.
+            draft_parallel_config: ParallelConfig for the draft model.
+            num_speculative_tokens: The number of tokens to sample from the
+                draft model before scoring with the target model.
+        """
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens < 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             f"than zero ({self.num_speculative_tokens}).")
+
+        if self.draft_model_config:
+            self.draft_model_config.verify_with_parallel_config(
+                self.draft_parallel_config)
+
+    @property
+    def num_lookahead_slots(self) -> int:
+        """The number of additional slots the scheduler should allocate per
+        step, in addition to the slots allocated for each known token.
+
+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+    def __repr__(self) -> str:
+        draft_model = self.draft_model_config.model
+        num_spec_tokens = self.num_speculative_tokens
+        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int
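
A minimal sketch of the early-exit behavior of maybe_create_spec_config, based only on the code above. Passing None for the target configs works here only because both branches return or raise before the draft ModelConfig is built; the full path (both arguments supplied) needs real ModelConfig/ParallelConfig instances and is not shown:

import pytest

from vllm.config import SpeculativeConfig

# Neither knob set: speculative decoding is disabled and no config is created.
assert SpeculativeConfig.maybe_create_spec_config(
    target_model_config=None,
    target_parallel_config=None,
    target_dtype="half",
    speculative_model=None,
    num_speculative_tokens=None) is None

# A draft model without a token count is rejected before any config is built.
with pytest.raises(ValueError, match="num_speculative_tokens"):
    SpeculativeConfig.maybe_create_spec_config(
        target_model_config=None,
        target_parallel_config=None,
        target_dtype="half",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=None)
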
@@ -833,3 +986,36 @@ def _get_and_verify_max_len(
         "to incorrect model outputs or CUDA errors. Make sure the "
         "value is correct and within the model context size.")
     return int(max_model_len)
+
+
+@dataclass(frozen=True)
+class EngineConfig:
+    """Dataclass which contains all engine-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    model_config: ModelConfig
+    cache_config: CacheConfig
+    parallel_config: ParallelConfig
+    scheduler_config: SchedulerConfig
+    device_config: DeviceConfig
+    lora_config: Optional[LoRAConfig]
+    vision_language_config: Optional[VisionLanguageConfig]
+    speculative_config: Optional[SpeculativeConfig]
+
+    def __post_init__(self):
+        """Verify configs are valid & consistent with each other.
+        """
+
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+
+    def as_dict(self):
+        """Return the configs as a dictionary, for use in **kwargs.
+        """
+        return asdict(self)
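
EngineConfig.as_dict() wraps dataclasses.asdict, so the frozen dataclass can be expanded into keyword arguments. A hedged sketch of the **kwargs consumption pattern the docstring refers to; build_executor here is hypothetical, not an API from this commit:

from vllm.engine.arg_utils import EngineArgs

engine_config = EngineArgs(model="facebook/opt-125m").create_engine_config()


def build_executor(model_config, cache_config, parallel_config,
                   scheduler_config, device_config, lora_config,
                   vision_language_config, speculative_config):
    # Hypothetical consumer: one keyword argument per EngineConfig field.
    return model_config, speculative_config


executor_inputs = build_executor(**engine_config.as_dict())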
