Skip to content

Commit

Permalink
[Misc] Add quantization config support for speculative model. (vllm-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
ShangmingCai authored and omrishiv committed Aug 26, 2024
1 parent 8700ba9 commit 83beb2e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 4 deletions.
48 changes: 48 additions & 0 deletions tests/spec_decode/e2e/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,51 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
max_output_len=output_len,
force_output_len=True,
)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_model_quantization": "gptq",
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_model_quantization": "marlin",
},
# Not explicitly specify draft model quantization
{
"speculative_model_quantization": None,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(baseline_llm_generator,
test_llm_generator,
batch_size: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)
12 changes: 8 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,7 @@ def maybe_create_spec_config(
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
Expand Down Expand Up @@ -989,6 +990,9 @@ def maybe_create_spec_config(
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_model_quantization (Optional[str]): Quantization method
that was used to quantize the speculative model weights. If
None, we assume the model weights are not quantized.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
Expand Down Expand Up @@ -1056,11 +1060,11 @@ def maybe_create_spec_config(
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")

# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = None
draft_quantization = speculative_model_quantization

if speculative_model == "[ngram]":
if ngram_prompt_lookup_min is None:
Expand Down Expand Up @@ -1217,7 +1221,7 @@ def create_draft_parallel_config(
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be"
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")

draft_parallel_config = ParallelConfig(
Expand Down
15 changes: 15 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class EngineArgs:
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
Expand Down Expand Up @@ -572,6 +573,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
# Quantization settings for speculative model.
parser.add_argument(
'--speculative-model-quantization',
type=nullable_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.speculative_model_quantization,
help='Method used to quantize the weights of speculative model.'
'If None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
Expand Down Expand Up @@ -846,6 +859,8 @@ def create_engine_config(self, ) -> EngineConfig:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
Expand Down

0 comments on commit 83beb2e

Please sign in to comment.