
Commit 30e80d3

hmellor and minpeter authored and committed
Improve configs - SpeculativeConfig (vllm-project#16971)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent f6b992a · commit 30e80d3

File tree: 2 files changed (+85, -104 lines)

vllm/config.py
vllm/engine/arg_utils.py

vllm/config.py

Lines changed: 73 additions & 99 deletions
@@ -2128,139 +2128,113 @@ def __post_init__(self):
         self.device = torch.device(self.device_type)
 
 
+SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
+                            "draft_model"]
+SpeculativeAcceptanceMethod = Literal["rejection_sampler",
+                                      "typical_acceptance_sampler"]
+
+
+@config
 @dataclass
 class SpeculativeConfig:
-    """
-    Configuration for speculative decoding.
-    Configurable parameters include:
-    - General Speculative Decoding Control:
-        - num_speculative_tokens (int): The number of speculative
-            tokens, if provided. It will default to the number in the draft
-            model config if present, otherwise, it is required.
-        - model (Optional[str]): The name of the draft model, eagle head,
-            or additional weights, if provided.
-        - method (Optional[str]): The name of the speculative method to use.
-            If users provide and set the `model` param, the speculative method
-            type will be detected automatically if possible, if `model` param
-            is not provided, the method name must be provided.
-            - Possible values:
-                - ngram
-                    Related additional configuration:
-                    - prompt_lookup_max (Optional[int]):
-                        Maximum size of ngram token window when using Ngram
-                        proposer, required when method is set to ngram.
-                    - prompt_lookup_min (Optional[int]):
-                        Minimum size of ngram token window when using Ngram
-                        proposer, if provided. Defaults to 1.
-                - eagle
-                - medusa
-                - mlp_speculator
-                - draft_model
-        - acceptance_method (str): The method to use for accepting draft
-            tokens. This can take two possible values: 'rejection_sampler' and
-            'typical_acceptance_sampler' for RejectionSampler and
-            TypicalAcceptanceSampler respectively. If not specified, it
-            defaults to 'rejection_sampler'.
-            - Possible values:
-                - rejection_sampler
-                - typical_acceptance_sampler
-                    Related additional configuration:
-                    - posterior_threshold (Optional[float]):
-                        A threshold value that sets a lower bound on the
-                        posterior probability of a token in the target model
-                        for it to be accepted. This threshold is used only
-                        when we use the TypicalAcceptanceSampler for token
-                        acceptance.
-                    - posterior_alpha (Optional[float]):
-                        Scaling factor for entropy-based threshold, applied
-                        when using TypicalAcceptanceSampler.
-        - draft_tensor_parallel_size (Optional[int]): The degree of the tensor
-            parallelism for the draft model. Can only be 1 or the same as the
-            target model's tensor parallel size.
-        - disable_logprobs (bool): If set to True, token log probabilities are
-            not returned during speculative decoding. If set to False, token
-            log probabilities are returned according to the log probability
-            settings in SamplingParams. If not specified, it defaults to True.
-
-    - Draft Model Configuration:
-        - quantization (Optional[str]): Quantization method that was used to
-            quantize the draft model weights. If None, we assume the
-            model weights are not quantized. Note that it only takes effect
-            when using the draft model-based speculative method.
-        - max_model_len (Optional[int]): The maximum model length of the
-            draft model. Used when testing the ability to skip
-            speculation for some sequences.
-        - revision: The specific model version to use for the draft model. It
-            can be a branch name, a tag name, or a commit id. If unspecified,
-            will use the default version.
-        - code_revision: The specific revision to use for the draft model code
-            on Hugging Face Hub. It can be a branch name, a tag name, or a
-            commit id. If unspecified, will use the default version.
+    """Configuration for speculative decoding."""
 
-    - Advanced Control:
-        - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
-            batch expansion for scoring proposals. If not specified, it
-            defaults to False.
-        - disable_by_batch_size (Optional[int]): Disable speculative decoding
-            for new incoming requests when the number of enqueued requests is
-            larger than this value, if provided.
-
-    Although the parameters above are structured hierarchically, there is no
-    need to nest them during configuration.
-
-    Non-configurable internal parameters include:
-    - Model Configuration:
-        - target_model_config (ModelConfig): The configuration of the target
-            model.
-        - draft_model_config (ModelConfig): The configuration of the draft
-            model initialized internal.
-    - Parallelism Configuration:
-        - target_parallel_config (ParallelConfig): The parallel configuration
-            for the target model.
-        - draft_parallel_config (ParallelConfig): The parallel configuration
-            for the draft model initialized internal.
-    - Execution Control:
-        - enable_chunked_prefill (bool): Whether vLLM is configured to use
-            chunked prefill or not. Used for raising an error since it's not
-            yet compatible with speculative decode.
-        - disable_log_stats (bool): Whether to disable the periodic printing of
-            stage times in speculative decoding.
-    """
-    # speculative configs from cli args
+    # General speculative decoding control
     num_speculative_tokens: int = field(default=None,
                                         init=True)  # type: ignore
-    method: Optional[str] = None
-    acceptance_method: str = "rejection_sampler"
+    """The number of speculative tokens, if provided. It will default to the
+    number in the draft model config if present, otherwise, it is required."""
+    model: Optional[str] = None
+    """The name of the draft model, eagle head, or additional weights, if
+    provided."""
+    method: Optional[SpeculativeMethod] = None
+    """The name of the speculative method to use. If users provide and set the
+    `model` param, the speculative method type will be detected automatically
+    if possible, if `model` param is not provided, the method name must be
+    provided.
+
+    If using `ngram` method, the related configuration `prompt_lookup_max` and
+    `prompt_lookup_min` should be considered."""
+    acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
+    """The method to use for accepting draft tokens:\n
+    - "rejection_sampler" maps to `RejectionSampler`.\n
+    - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
+
+    If using `typical_acceptance_sampler`, the related configuration
+    `posterior_threshold` and `posterior_alpha` should be considered."""
     draft_tensor_parallel_size: Optional[int] = None
+    """The degree of the tensor parallelism for the draft model. Can only be 1
+    or the same as the target model's tensor parallel size."""
     disable_logprobs: bool = True
+    """If set to True, token log probabilities are not returned during
+    speculative decoding. If set to False, token log probabilities are returned
+    according to the log probability settings in SamplingParams."""
 
-    model: Optional[str] = None
+    # Draft model configuration
     quantization: Optional[str] = None
+    """Quantization method that was used to quantize the draft model weights.
+    If `None`, we assume the model weights are not quantized. Note that it only
+    takes effect when using the draft model-based speculative method."""
     max_model_len: Optional[int] = None
+    """The maximum model length of the draft model. Used when testing the
+    ability to skip speculation for some sequences."""
     revision: Optional[str] = None
+    """The specific model version to use for the draft model. It can be a
+    branch name, a tag name, or a commit id. If unspecified, will use the
+    default version."""
     code_revision: Optional[str] = None
+    """The specific revision to use for the draft model code on Hugging Face
+    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
+    will use the default version."""
 
+    # Advanced control
     disable_mqa_scorer: bool = False
+    """Disable the MQA scorer and fall back to batch expansion for scoring
+    proposals."""
     disable_by_batch_size: Optional[int] = None
+    """Disable speculative decoding for new incoming requests when the number
+    of enqueued requests is larger than this value, if provided."""
+
+    # Ngram proposer configuration
     prompt_lookup_max: Optional[int] = None
+    """Maximum size of ngram token window when using Ngram proposer, required
+    when method is set to ngram."""
     prompt_lookup_min: Optional[int] = None
+    """Minimum size of ngram token window when using Ngram proposer, if
+    provided. Defaults to 1."""
+
+    # Typical acceptance sampler configuration
     posterior_threshold: Optional[float] = None
+    """A threshold value that sets a lower bound on the posterior probability
+    of a token in the target model for it to be accepted. This threshold is
+    used only when we use the `TypicalAcceptanceSampler` for token acceptance.
+    """
     posterior_alpha: Optional[float] = None
+    """Scaling factor for entropy-based threshold, applied when using
+    `TypicalAcceptanceSampler`."""
 
     # required configuration params passed from engine
     target_model_config: ModelConfig = field(default=None,
                                              init=True)  # type: ignore
+    """The configuration of the target model."""
     target_parallel_config: ParallelConfig = field(default=None,
                                                    init=True)  # type: ignore
+    """The parallel configuration for the target model."""
     enable_chunked_prefill: bool = field(default=None,
                                          init=True)  # type: ignore
+    """Whether vLLM is configured to use chunked prefill or not. Used for
+    raising an error since it's not yet compatible with speculative decode."""
     disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    """Whether to disable the periodic printing of stage times in speculative
+    decoding."""
 
     # params generated in the post-init stage
     draft_model_config: ModelConfig = field(default=None,
                                             init=True)  # type: ignore
+    """The configuration of the draft model initialized internal."""
     draft_parallel_config: ParallelConfig = field(default=None,
                                                   init=True)  # type: ignore
+    """The parallel configuration for the draft model initialized internal."""
 
     def compute_hash(self) -> str:
         """

vllm/engine/arg_utils.py

Lines changed: 12 additions & 5 deletions
@@ -768,11 +768,18 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
 
-        parser.add_argument('--speculative-config',
-                            type=json.loads,
-                            default=None,
-                            help='The configurations for speculative decoding.'
-                            ' Should be a JSON string.')
+        # Speculative arguments
+        speculative_group = parser.add_argument_group(
+            title="SpeculativeConfig",
+            description=SpeculativeConfig.__doc__,
+        )
+        speculative_group.add_argument(
+            '--speculative-config',
+            type=json.loads,
+            default=None,
+            help='The configurations for speculative decoding.'
+            ' Should be a JSON string.')
+
         parser.add_argument(
             '--ignore-patterns',
             action="append",
