43 changes: 37 additions & 6 deletions tests/conftest.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import math
import os
import tempfile
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast

import numpy as np
import pytest
@@ -33,6 +34,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect

logger = init_logger(__name__)
@@ -602,7 +604,7 @@ def _hidden_states_to_seq_logprobs(
def _hidden_states_to_logprobs(
self,
hidden_states: tuple[tuple[torch.Tensor, ...], ...],
num_logprobs: int,
num_logprobs: Optional[int],
) -> tuple[list[dict[int, float]], int]:
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
output_len = len(hidden_states)
@@ -630,7 +632,7 @@ def generate_greedy_logprobs_limit(
self,
prompts: list[str],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
@@ -677,7 +679,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
self,
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> list[TokensTextLogprobs]:
@@ -966,7 +968,7 @@ def generate_greedy_logprobs(
self,
prompts: list[str],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
@@ -991,11 +993,40 @@ def generate_greedy_logprobs(
videos=videos,
**kwargs)

def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
"""
Return the perplexity score associated with generating the prompts

:param prompts: list of prompts to score
:return: perplexity score of each prompt
"""
outputs = self.generate_greedy_logprobs(prompts,
max_tokens=1,
num_logprobs=None,
num_prompt_logprobs=0)

perplexities = []
for output in outputs:
output = cast(TokensTextLogprobsPromptLogprobs, output)
token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])
assert token_datas[0] is None
token_log_probs = []
for token_data in token_datas[1:]:
assert token_data is not None
assert len(token_data) == 1
token_log_prob = list(token_data.values())[0].logprob
token_log_probs.append(token_log_prob)

perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs))
perplexities.append(perplexity)

return perplexities

def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
) -> Union[list[TokensTextLogprobs],
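For reference, generate_prompt_perplexity reduces to the standard token-level formulation: the first prompt token carries no logprob (nothing precedes it), and the remaining per-token log-probabilities are averaged and exponentiated. A minimal standalone sketch of the same arithmetic, with an illustrative function name:

import math

def prompt_perplexity(token_logprobs: list[float]) -> float:
    # perplexity = exp(-(1/N) * sum of per-token log-probabilities)
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

# Example: three tokens scored at log p = -2.0 each -> exp(2.0) ≈ 7.39
print(prompt_perplexity([-2.0, -2.0, -2.0]))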
22 changes: 22 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -719,3 +719,25 @@ def check_model(model):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("model,prompt,exp_perplexity", [
(
"nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
"Flat is better than nested.\nSparse is better than dense.",
150.0,
),
(
"nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
"Flat is better than nested.\nSparse is better than dense.",
150.0,
),
])
def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
exp_perplexity):
with vllm_runner(model, enforce_eager=True) as llm:
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity
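As a sanity check on the 150.0 ceilings above: a prompt perplexity of at most 150 is equivalent to an average negative log-likelihood of at most ln(150) ≈ 5.01 nats per scored prompt token, which suggests these thresholds are loose upper bounds meant to catch gross accuracy regressions rather than to track exact model quality.

import math
# perplexity <= 150  <=>  mean negative log-likelihood <= ln(150)
print(round(math.log(150.0), 2))  # 5.01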
16 changes: 15 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -35,6 +35,7 @@

WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod",
@@ -199,6 +200,7 @@ def create_weights(self, layer: torch.nn.Module,
set_weight_attrs(weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# special postprocessing for CPU SGL
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
@@ -1470,7 +1472,7 @@ def __init__(self,
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
"weight_loader": self.weight_loader_v1,
})
else:
self.bias = None
@@ -1580,6 +1582,18 @@ def forward( # type: ignore[override]
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_v1(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# just like all other parameters, does not yet
# support loading bias with weight_loader_v2
layer = (self.q_proj_decoder
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)

def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
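The new weight_loader_v1 mirrors the shard routing this layer already uses for weights: a "q" shard id targets the decoder's Q projection, while "k"/"v" shards target the fused encoder KV projection and keep their shard id. A minimal sketch of that routing (names are illustrative, and the per-layer parameter selection is omitted):

import torch

def route_bias_load(shard_id, param, loaded_bias: torch.Tensor,
                    q_proj_decoder, kv_proj_encoder):
    # "q" biases belong to the decoder projection; "k"/"v" biases belong to
    # the fused encoder KV projection, which still needs the shard id.
    layer = q_proj_decoder if shard_id == "q" else kv_proj_encoder
    shard_args = () if shard_id == "q" else (shard_id,)
    layer.weight_loader(param, loaded_bias, *shard_args)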
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -11,6 +11,7 @@
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel

import vllm.envs as envs
@@ -30,6 +31,8 @@
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -60,6 +63,7 @@ def __init__(
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[TransformConfig] = None,
):
super().__init__()
self.ignore = ignore
@@ -71,6 +75,12 @@
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config

if transform_config is not None:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)

@@ -103,18 +113,27 @@ def get_quant_method(
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
input_tfms, output_tfms = get_linear_transform_schemes(
layer, prefix, self.transform_config,
self.packed_modules_mapping)

# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)

# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)

else:
return quant_method

if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
@@ -129,6 +148,7 @@ def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")

return cls(
target_scheme_map=target_scheme_map,
@@ -137,6 +157,7 @@ def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
)

@classmethod
@@ -537,9 +558,11 @@ def get_scheme(self,

# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None

# Will be empty for models with only sparsity
weight_quant = input_quant = None
@@ -722,7 +745,6 @@ def apply(self,
layer input. See LinearMethodBase for param details

"""

scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
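Taken together, the config changes keep transforms orthogonal to quantization: get_quant_method first picks a base linear method (quantized or unquantized), then wraps it in CompressedTensorsLinearTransformMethod only when the parsed transform_config yields input or output transform schemes for the layer. A simplified sketch of that selection order, with hypothetical factory callables standing in for the real constructors:

def select_linear_method(quant_scheme, input_tfms, output_tfms,
                         make_quant_method, wrap_with_transforms,
                         unquantized_method):
    # The quantization scheme (if any) decides the base method.
    method = (make_quant_method(quant_scheme)
              if quant_scheme is not None else unquantized_method)
    # Transforms, when present, wrap whichever base method was chosen.
    if input_tfms or output_tfms:
        return wrap_with_transforms(method, input_tfms, output_tfms)
    return method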