@@ -11,6 +11,7 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.utils import get_device_capability_stateless
 
 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
     def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool) -> None:
+                 is_sym: bool, lm_head_quantized: bool) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -69,6 +70,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
         self.group_size = group_size
         self.desc_act = desc_act
         self.is_sym = is_sym
+        self.lm_head_quantized = lm_head_quantized
 
         # Verify
         if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@@ -96,7 +98,8 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -120,7 +123,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         is_sym = cls.get_from_keys(config, ["sym"])
-        return cls(weight_bits, group_size, desc_act, is_sym)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, is_sym,
+                   lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -145,7 +151,8 @@ def override_quantization_method(cls, hf_quant_cfg,
     def get_quant_method(
             self,
             layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQMarlinLinearMethod(self)
         return None
 
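For context on the `from_config` hunk: the new `lm_head` key is read with a fallback so older GPTQ checkpoints that omit it keep the old behavior. Below is a minimal standalone sketch of that parsing, assuming only what the hunk shows; `get_from_keys_or` here is a local stand-in mirroring the `QuantizationConfig` helper the diff calls, and the sample `quantize_config` dict is hypothetical.

```python
from typing import Any, Dict, List


def get_from_keys_or(config: Dict[str, Any], keys: List[str],
                     default: Any) -> Any:
    # Stand-in with the same contract the diff relies on: return the value of
    # the first key present in the config, otherwise the supplied default.
    for key in keys:
        if key in config:
            return config[key]
    return default


# Hypothetical quantize_config.json contents for a checkpoint whose lm_head
# was quantized; older GPTQ checkpoints simply omit the "lm_head" key.
quantize_config = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "sym": True,
    "lm_head": True,
}

lm_head_quantized = get_from_keys_or(quantize_config, ["lm_head"],
                                     default=False)
print(lm_head_quantized)  # True; a config without the key falls back to False
```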
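The `get_quant_method` hunk widens the dispatch rule: linear layers always get the Marlin method, while the LM head gets it only when the checkpoint opted in. A minimal sketch of that rule with placeholder classes (not vLLM imports):

```python
class LinearBase:
    ...  # placeholder for vllm's LinearBase


class ParallelLMHead:
    ...  # placeholder for vllm's ParallelLMHead


def wants_marlin_method(layer: object, lm_head_quantized: bool) -> bool:
    # Mirrors the new condition: ordinary linear layers are always quantized;
    # the LM head only when quantize_config set "lm_head": true.
    return isinstance(layer, LinearBase) or (
        isinstance(layer, ParallelLMHead) and lm_head_quantized)


print(wants_marlin_method(LinearBase(), False))      # True
print(wants_marlin_method(ParallelLMHead(), False))  # False: flag is off
print(wants_marlin_method(ParallelLMHead(), True))   # True: flag is on
```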