[Quant] Add SupportsQuant to phi3 and clip #13104

Merged (2 commits) on Feb 16, 2025
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/aqlm.py
@@ -169,6 +169,7 @@ def __init__(
num_codebooks: int,
out_group_size: int,
) -> None:
super().__init__()
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/awq.py
@@ -26,6 +26,7 @@ def __init__(
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -47,6 +47,7 @@ def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -2,7 +2,7 @@

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Type
from typing import Any, Dict, List, Optional, Type

import torch
from torch import nn
@@ -59,7 +59,11 @@ def method_has_implemented_embedding(

class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()

def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()

@abstractmethod
def get_name(self) -> str:
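Because `packed_modules_mapping` moves from a shared class attribute to a per-instance dict created in `QuantizationConfig.__init__`, every config subclass now has to chain to the base constructor, which is what the `super().__init__()` additions in the remaining files do. A minimal sketch of that contract, using a hypothetical `ToyQuantConfig` that is not part of vLLM:

```python
from typing import Dict, List


class QuantizationConfig:
    """Simplified stand-in for vLLM's base class (sketch only)."""

    def __init__(self) -> None:
        # The mapping is now created per instance; models update it as they
        # initialize instead of sharing one class-level dict.
        self.packed_modules_mapping: Dict[str, List[str]] = {}


class ToyQuantConfig(QuantizationConfig):
    """Hypothetical subclass mirroring the pattern used across this PR."""

    def __init__(self, weight_bits: int) -> None:
        super().__init__()  # required so packed_modules_mapping exists
        self.weight_bits = weight_bits


cfg = ToyQuantConfig(weight_bits=4)
# A model can now merge its packed-layer mapping into its own config instance.
cfg.packed_modules_mapping.update(
    {"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
assert cfg.packed_modules_mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]
```

Models that fuse layers (for example q/k/v projections) can then register their mapping on the config they were constructed with, without mutating state shared by every other config object.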
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -30,7 +30,7 @@ def __init__(
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:

super().__init__()
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
@@ -51,7 +51,7 @@ def __init__(
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):

super().__init__()
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -25,6 +25,7 @@ def __init__(
weight_bits: int = 8,
group_size: int = 512,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16]
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/experts_int8.py
@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""

def __init__(self) -> None:
pass
super().__init__()

@classmethod
def get_name(cls) -> str:
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""

def __init__(self, ignore_list: List[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub

1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -47,6 +47,7 @@ def __init__(
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gguf.py
@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""

def __init__(self, ) -> None:
pass
super().__init__()

def __repr__(self) -> str:
return ("GGUFConfig()")
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -58,6 +58,7 @@ def __init__(
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
super().__init__()
self.dynamic = dynamic

self.weight_bits = weight_bits
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -46,6 +46,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
@@ -38,6 +38,7 @@ def __init__(
weight_bits: int,
group_size: int,
) -> None:
super().__init__()
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -33,6 +33,7 @@ def __init__(
group_size: int,
skip_modules: Optional[List[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/ipex_quant.py
@@ -35,6 +35,7 @@ def __init__(
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
super().__init__()
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -28,6 +28,7 @@ def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -24,6 +24,7 @@ def __init__(self, linear_quant_method: str, weight_bits: int,
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.has_zp = has_zp
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/neuron_quant.py
@@ -20,6 +20,7 @@ def __init__(
dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
super().__init__()
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError(
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/qqq.py
@@ -39,6 +39,7 @@ def __init__(
group_size: int,
is_sym: bool = True,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.is_sym = is_sym
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -30,6 +30,7 @@ def __init__(self,
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None:
kv_cache_group = []
self.quant_config = quant_config
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/tpu_int8.py
@@ -21,6 +21,7 @@ def __init__(
self,
activation_scheme: str = "none",
) -> None:
super().__init__()
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/clip.py
@@ -15,6 +15,7 @@
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant

from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

@@ -335,10 +336,10 @@ def forward(
return encoder_outputs


class CLIPVisionModel(nn.Module):

class CLIPVisionModel(nn.Module, SupportsQuant):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

def __init__(
self,
32 changes: 32 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -7,6 +7,8 @@
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw

from .interfaces_base import is_pooling_model
@@ -443,6 +445,36 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)


class SupportsQuant:
"""The interface required for all models that support quantization."""

packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
quant_config: Optional[QuantizationConfig] = None

def __new__(cls, *args, **kwargs) -> "SupportsQuant":
instance = super().__new__(cls)
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
instance.quant_config = quant_config
instance.quant_config.packed_modules_mapping.update(
cls.packed_modules_mapping)
return instance

@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
from vllm.config import VllmConfig # avoid circular import

args_values = list(args) + list(kwargs.values())
for arg in args_values:
if isinstance(arg, VllmConfig):
return arg.quant_config

if isinstance(arg, QuantizationConfig):
return arg

return None


@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""
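The `SupportsQuant` mixin hooks `__new__` so that, when a model object is constructed, it scans the constructor arguments for a `VllmConfig` or `QuantizationConfig`, stores it as `self.quant_config`, and merges the model's class-level `packed_modules_mapping` into that config. A self-contained sketch of that flow, using toy stand-ins (`ToyQuantConfig`, `ToyVisionModel`) rather than the real vLLM classes:

```python
from typing import Any, ClassVar, Dict, List, Optional


class ToyQuantConfig:
    """Hypothetical config standing in for QuantizationConfig."""

    def __init__(self) -> None:
        self.packed_modules_mapping: Dict[str, List[str]] = {}


class SupportsQuant:
    """Sketch of the interface added in interfaces.py (simplified)."""

    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
    quant_config: Optional[ToyQuantConfig] = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "SupportsQuant":
        instance = super().__new__(cls)
        # Scan constructor arguments for a quant config, if one was passed.
        for arg in list(args) + list(kwargs.values()):
            if isinstance(arg, ToyQuantConfig):
                instance.quant_config = arg
                # Merge the model's packed-layer mapping into the config.
                arg.packed_modules_mapping.update(cls.packed_modules_mapping)
                break
        return instance


class ToyVisionModel(SupportsQuant):
    """Hypothetical model mixing in SupportsQuant, like CLIPVisionModel."""

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, quant_config: Optional[ToyQuantConfig] = None) -> None:
        # self.quant_config was already set by SupportsQuant.__new__.
        pass


cfg = ToyQuantConfig()
model = ToyVisionModel(quant_config=cfg)
assert model.quant_config is cfg
assert cfg.packed_modules_mapping == {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"]
}
```

Doing this in `__new__` means the config is attached before the subclass `__init__` runs, so code like `Phi3VForCausalLM.__init__` below can read `self.quant_config` directly instead of pulling it out of `vllm_config` itself.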
10 changes: 5 additions & 5 deletions vllm/model_executor/models/phi3v.py
@@ -50,7 +50,7 @@
from vllm.utils import is_list_of

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -498,7 +498,8 @@ def _apply_prompt_replacements(
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
@@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
@@ -520,14 +520,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)

# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
quant_config,
self.quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

self.language_model = init_vllm_registered_model(