diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py
index 4a2b62151f8cd..b44d269fa7382 100644
--- a/tests/spec_decode/e2e/test_integration.py
+++ b/tests/spec_decode/e2e/test_integration.py
@@ -42,3 +42,51 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
         max_output_len=output_len,
         force_output_len=True,
     )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
+        "num_speculative_tokens": 5,
+    },
+])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        # Explicitly specify draft model quantization
+        {
+            "speculative_model_quantization": "gptq",
+        },
+        # Explicitly specify GPTQ-based draft model to use marlin quantization
+        {
+            "speculative_model_quantization": "marlin",
+        },
+        # Do not explicitly specify draft model quantization
+        {
+            "speculative_model_quantization": None,
+        },
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_speculative_model_quantization_config(baseline_llm_generator,
+                                               test_llm_generator,
+                                               batch_size: int):
+    """Verify spec decode works well with draft model quantization configs.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=32,
+                                         force_output_len=True)
diff --git a/vllm/config.py b/vllm/config.py
index b564a0c68cef8..19cd4d8b51d7c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -961,6 +961,7 @@ def maybe_create_spec_config(
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         speculative_model: Optional[str],
+        speculative_model_quantization: Optional[str],
         speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
@@ -989,6 +990,9 @@ def maybe_create_spec_config(
             target_dtype (str): The data type used for the target model.
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
+            speculative_model_quantization (Optional[str]): Quantization method
+                that was used to quantize the speculative model weights. If
+                None, we assume the model weights are not quantized.
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
                 of the tensor parallelism for the draft model.
             num_speculative_tokens (Optional[int]): The number of speculative
@@ -1056,11 +1060,11 @@ def maybe_create_spec_config(
                 "Speculative decoding requires usage of the V2 "
                 "block manager. Enable it with --use-v2-block-manager.")
 
-        # TODO: The user should be able to specify revision/quantization/max
-        # model len for the draft model. It is not currently supported.
+        # TODO: The user should be able to specify revision/max model len
+        # for the draft model. It is not currently supported.
         draft_revision = None
         draft_code_revision = None
-        draft_quantization = None
+        draft_quantization = speculative_model_quantization
 
         if speculative_model == "[ngram]":
             if ngram_prompt_lookup_min is None:
@@ -1217,7 +1221,7 @@ def create_draft_parallel_config(
         elif speculative_draft_tensor_parallel_size != 1:
             # TODO(wooyeon): allow tp values larger than 1
             raise ValueError(
-                f"{speculative_draft_tensor_parallel_size=} cannot be"
+                f"{speculative_draft_tensor_parallel_size=} cannot be "
                 f"other value than 1")
 
         draft_parallel_config = ParallelConfig(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 48d01fcfd8f5f..315b2f50a919d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -129,6 +129,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
+    speculative_model_quantization: Optional[str] = None
     speculative_draft_tensor_parallel_size: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
@@ -571,6 +572,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.speculative_model,
             help=
             'The name of the draft model to be used in speculative decoding.')
+        # Quantization settings for the speculative model.
+        parser.add_argument(
+            '--speculative-model-quantization',
+            type=nullable_str,
+            choices=[*QUANTIZATION_METHODS, None],
+            default=EngineArgs.speculative_model_quantization,
+            help='Method used to quantize the weights of the speculative '
+            'model. If None, we first check the `quantization_config` '
+            'attribute in the model config file. If that is '
+            'None, we assume the model weights are not '
+            'quantized and use `dtype` to determine the data '
+            'type of the weights.')
         parser.add_argument(
             '--num-speculative-tokens',
             type=int,
@@ -844,6 +857,8 @@ def create_engine_config(self, ) -> EngineConfig:
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
+            speculative_model_quantization = \
+                self.speculative_model_quantization,
             speculative_draft_tensor_parallel_size = \
                 self.speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
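
For reference, here is a minimal offline sketch of how the new argument is exercised once this diff is applied, mirroring the e2e test above. The model names are the ones the test uses; passing `speculative_model_quantization` as an `LLM` keyword argument is assumed to reach `EngineArgs` the same way the other speculative-decoding options do.

```python
from vllm import LLM, SamplingParams

# The draft model is GPTQ-quantized. With the flag set to "gptq" the draft
# worker loads GPTQ weights directly; with None it would fall back to the
# `quantization_config` attribute in the draft model's config file, per the
# help text added in this diff.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    speculative_model_quantization="gptq",
    num_speculative_tokens=5,
    use_v2_block_manager=True,  # required for spec decode
    enforce_eager=True,         # skip cuda graph capture, as in the test
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

The same setting is exposed on the server command line as `--speculative-model-quantization gptq`, with valid choices restricted to `QUANTIZATION_METHODS`.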