diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 2e51e95a38f2e..7207af6b1a4b3 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -475,10 +475,10 @@ def _pretest():
 
         lora_result = lora_logits_processor._get_logits(
             hidden_states=torch.cat(inputs),
-            embedding=linear.weight,
+            lm_head=linear,
             embedding_bias=None)
 
-        original_weight = linear.weight.clone()
+        original_lm_head = deepcopy(linear)
 
         linear.weight[logits_processor.
                       org_vocab_size:logits_processor.org_vocab_size +
@@ -490,7 +490,7 @@ def _pretest():
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
             result = logits_processor._get_logits(hidden_states=input_,
-                                                  embedding=linear.weight,
+                                                  lm_head=linear,
                                                   embedding_bias=None)
             result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
             result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
@@ -519,11 +519,11 @@ def _pretest():
 
         lora_result = lora_logits_processor._get_logits(
             hidden_states=torch.cat(inputs),
-            embedding=original_weight,
+            lm_head=original_lm_head,
             embedding_bias=None)[:, :vocab_size]
         expected_result = logits_processor._get_logits(
             hidden_states=torch.cat(inputs),
-            embedding=original_weight,
+            lm_head=original_lm_head,
             embedding_bias=None)
 
         rtol, atol = TOLERANCES[lora_result.dtype]
diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py
new file mode 100644
index 0000000000000..dd9a016807df9
--- /dev/null
+++ b/tests/quantization/test_lm_head.py
@@ -0,0 +1,45 @@
+"""Tests whether gptq models with quantized lm_head can be loaded.
+
+Run `pytest tests/quantization/test_lm_head.py --forked`.
+"""
+from typing import Tuple
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinLinearMethod)
+from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+PROMPT = "On the surface of Mars, we found"
+
+MODELS_QUANT = [(
+    "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
+    True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
+                ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
+
+
+@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
+def test_lm_head(
+    vllm_runner,
+    model_lm_head_quant: Tuple[str, bool],
+) -> None:
+    model, lm_head_quantized = model_lm_head_quant
+    vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
+
+    lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model.lm_head)
+
+    if lm_head_quantized:
+        assert isinstance(
+            lm_head_layer.linear_method,
+            (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
+    else:
+        assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
+
+    print(
+        vllm_model.generate_greedy(prompts=["Hello my name is"],
+                                   max_tokens=10)[0][1])
+    del vllm_model
diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index 9a9f2acbb8f39..dd67a7735a647 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -34,7 +34,7 @@
 MAX_SPEC_TOKENS = 5
 
 # precision
-PRECISION = "float16"
+PRECISION = "float32"
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index 4ee980505a3ab..8ee2d78190cd1 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -83,7 +83,7 @@ def pick_ith(token_ids, logits):
                                      device=device,
                                      pin_memory=is_pin_memory_available())
     logits_processor_output = logits_processor(
-        embedding=None,
+        lm_head=None,
         hidden_states=input_tensor,
         sampling_metadata=sampling_metadata)
 
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 2fddfccaf1e4c..0a63f9ef012bc 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1172,11 +1172,11 @@ def set_mapping(
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
+        logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 8062bfb5194bc..f6fcf49ef464b 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -6,6 +6,8 @@
 import torch.nn as nn
 
 from vllm.distributed import tensor_model_parallel_gather
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 
@@ -40,7 +42,7 @@ def __init__(self,
 
     def forward(
         self,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
@@ -52,8 +54,7 @@ def forward(
                                                  sampling_metadata)
 
         # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
+        logits = self._get_logits(hidden_states, lm_head, embedding_bias)
         if logits is not None:
             if self.soft_cap is not None:
                 logits = logits / self.soft_cap
@@ -68,12 +69,13 @@ def forward(
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+    def _get_logits(self, hidden_states: torch.Tensor,
+                    lm_head: VocabParallelEmbedding,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
+        logits = lm_head.linear_method.apply(lm_head,
+                                             hidden_states,
+                                             bias=embedding_bias)
         logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index c23b66161d9b8..1607470cb76f6 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -87,6 +87,15 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")
 
+    @staticmethod
+    def get_from_keys_or(config: Dict[str, Any], keys: List[str],
+                         default: Any) -> Any:
+        """Get an optional value from the model's quantization config."""
+        try:
+            return QuantizationConfig.get_from_keys(config, keys)
+        except ValueError:
+            return default
+
     @abstractmethod
     def get_quant_method(
         self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index ae9f7019f0592..595d6ab96b1b9 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -10,6 +10,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs
 
 
@@ -24,10 +25,12 @@ def __init__(
         weight_bits: int,
         group_size: int,
         desc_act: bool,
+        lm_head_quantized: bool,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
         self.pack_factor = Fraction(32, self.weight_bits)
         if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
@@ -37,7 +40,8 @@ def __init__(
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -61,11 +65,14 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
-        return cls(weight_bits, group_size, desc_act)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQLinearMethod(self)
         return None
 
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index c6e9279c8baea..97aae33f133be 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -11,6 +11,7 @@
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.utils import get_device_capability_stateless
 
 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
     def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool) -> None:
+                 is_sym: bool, lm_head_quantized: bool) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -69,6 +70,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
         self.group_size = group_size
         self.desc_act = desc_act
         self.is_sym = is_sym
+        self.lm_head_quantized = lm_head_quantized
 
         # Verify
         if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@@ -96,7 +98,8 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -120,7 +123,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         is_sym = cls.get_from_keys(config, ["sym"])
-        return cls(weight_bits, group_size, desc_act, is_sym)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, is_sym,
+                   lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -145,7 +151,8 @@ def override_quantization_method(cls, hf_quant_cfg,
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQMarlinLinearMethod(self)
         return None
 
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 3613c9d9ecf2a..f0a9cf5520bdd 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -8,6 +8,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
     def __init__(
         self,
         group_size: int,
+        lm_head_quantized: bool,
     ) -> None:
         # Group size for the quantization.
         self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
         if self.group_size != 128 and self.group_size != -1:
             raise ValueError(
                 "Currently, only group size 128 and -1 (channelwise) "
@@ -51,7 +54,8 @@ def __init__(
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return f"MarlinConfig(group_size={self.group_size})"
+        return (f"MarlinConfig(group_size={self.group_size}, "
+                f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -73,7 +77,9 @@ def get_config_filenames(cls) -> List[str]:
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
-        return cls(group_size)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(group_size, lm_head_quantized)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -96,7 +102,8 @@ def override_quantization_method(cls, hf_quant_cfg,
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return MarlinLinearMethod(self)
         return None
 
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 4650b2c2458d0..d70eb1c2704b4 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -8,6 +8,9 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -157,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         params_dtype: type of the parameters.
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
+        quant_config: quant config for the layer
     """  # noqa: E501
 
     def __init__(self,
@@ -164,7 +168,8 @@ def __init__(self,
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
         # Keep the input dimensions.
@@ -187,6 +192,14 @@ def __init__(self,
                                                self.org_vocab_size, tp_rank,
                                                self.tp_size)
         self.embedding_dim = embedding_dim
+
+        linear_method = None
+        if quant_config is not None:
+            linear_method = quant_config.get_quant_method(self)
+        if linear_method is None:
+            linear_method = UnquantizedLinearMethod()
+        self.linear_method: QuantizeMethodBase = linear_method
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocaburaly dimension.
@@ -201,14 +214,14 @@ def __init__(self,
         self.num_added_embeddings_per_partition = (
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)
-        self.weight = Parameter(
-            torch.empty(self.num_embeddings_per_partition,
-                        self.embedding_dim,
-                        dtype=params_dtype))
-        set_weight_attrs(self.weight, {
-            "parallel_dim": 0,
-            "weight_loader": self.weight_loader
-        })
+
+        self.linear_method.create_weights(self,
+                                          self.embedding_dim,
+                                          [self.num_embeddings_per_partition],
+                                          self.embedding_dim,
+                                          self.num_embeddings_padded,
+                                          params_dtype=params_dtype,
+                                          weight_loader=self.weight_loader)
 
     @classmethod
     def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -288,10 +301,32 @@ def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
         return ret
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
-        loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
-                                      self.shard_indices.org_vocab_end_index]
+        output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
+
+        # If parameter does not have output dim, then it should
+        # be copied onto all gpus (e.g. g_idx for act_order gptq).
+        if output_dim is None:
+            assert param.data.shape == loaded_weight.shape
+            param.data.copy_(loaded_weight)
+            return
+
+        # Shard indexes for loading the weight
+        start_idx = self.shard_indices.org_vocab_start_index
+        shard_size = self.shard_indices.org_vocab_end_index - start_idx
+
+        # If param packed on the same dim we are sharding on, then
+        # need to adjust offsets of loaded weight by pack_factor.
+        if packed_dim is not None and packed_dim == output_dim:
+            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
+                                                       param.pack_factor)
+            start_idx = start_idx // param.pack_factor
+            shard_size = shard_size // param.pack_factor
+        else:
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+
+        # Copy the data.
+        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
 
@@ -346,16 +381,17 @@ def __init__(self,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size)
+                         org_num_embeddings, padding_size, quant_config)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
-                "parallel_dim": 0,
-                "weight_loader": self.weight_loader
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
         else:
             self.register_parameter("bias", None)
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index fec52e0168851..49e57a847e847 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -412,6 +412,7 @@ def __init__(self,
         self.lm_head = ParallelLMHead(
             self.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.num_experts = config.num_local_experts
         self.num_experts_per_tok = config.num_experts_per_tok
@@ -434,7 +435,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index ddc4e908451af..e1ea8bfcac655 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -328,7 +328,9 @@ def __init__(
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, cache_config,
                                    quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -346,7 +348,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 8387c8e37bdd3..86ae32e0cb01f 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -276,7 +276,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = BloomModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head = self.transformer.word_embeddings
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -294,7 +294,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index e6012a6d4e784..553ddf90475b4 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -303,7 +303,8 @@ def __init__(
         self.encoder = GLMTransformer(config, cache_config, quant_config)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
-                                           config.hidden_size)
+                                           config.hidden_size,
+                                           quant_config=quant_config)
 
     def forward(
         self,
@@ -355,7 +356,7 @@ def __init__(
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.output_layer.weight
+        self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()
 
@@ -373,7 +374,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 2961f421eb6fc..5f6e3a134f408 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -363,12 +363,12 @@ def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
-            embedding_weights = self.model.embed_tokens.weight
+            logits = self.logits_processor(self.model.embed_tokens,
+                                           hidden_states, sampling_metadata)
         else:
-            embedding_weights = self.model.embed_tokens.base_layer.weight
+            logits = self.logits_processor(self.model.embed_tokens.base_layer,
+                                           hidden_states, sampling_metadata)
 
-        logits = self.logits_processor(embedding_weights, hidden_states,
-                                       sampling_metadata)
         return logits
 
     def sample(
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 210cf61652661..d758333b22388 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -370,6 +370,7 @@ def __init__(
             config.d_model,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -389,7 +390,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index e9ceca9b18c35..3fd6f2218f3eb 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -377,7 +377,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -395,7 +397,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 3cf62afd9b4ac..fb4097fd1e9b3 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -465,7 +465,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -483,7 +485,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 89b0bbf014dea..93f07327eaa26 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -394,13 +394,13 @@ def __init__(
                                     if config.tie_word_embeddings is not None
                                     else True)
         if self.tie_word_embeddings:
-            self.lm_head_weight = self.transformer.word_embeddings.weight
+            self.lm_head = self.transformer.word_embeddings
         else:
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -422,7 +422,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 0a5a7ed3d04e4..b603a59110915 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -347,8 +347,8 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 1f921c8bd0953..8fedff6255053 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -346,8 +346,8 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 55f2e27410dd7..be19f4ba8c71e 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -238,7 +238,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = GPT2Model(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -256,7 +256,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 7d0bf39c58f42..cc42413d53f4c 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -259,7 +259,7 @@ def __init__(
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -281,7 +281,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index de7f86af709e8..4bb9debe7ae81 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -229,6 +229,7 @@ def __init__(
             config.vocab_size,
             config.n_embd,
             bias=True,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -247,7 +248,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 3658b8fbf057e..b306574b2ed92 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -241,6 +241,7 @@ def __init__(
         self.embed_out = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -259,7 +260,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.embed_out.weight, hidden_states,
+        logits = self.logits_processor(self.embed_out, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 283bc064b596c..22132f40fc5e6 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -253,7 +253,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = InternLM2Model(config, cache_config, quant_config)
-        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.output = ParallelLMHead(config.vocab_size,
+                                     config.hidden_size,
+                                     quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -271,7 +273,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.output.weight, hidden_states,
+        logits = self.logits_processor(self.output, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 2758e2d0b59af..0030c761d34db 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -273,7 +273,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = JAISModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         if hasattr(config, "width_scale"):
             self.output_logits_scale = config.width_scale
         else:
@@ -297,7 +297,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index af75b6bee1041..77edcd7402db1 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -380,6 +380,7 @@ def __init__(
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
         )
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -403,7 +404,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 39c47dddf5070..bbec4dbd897c2 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -125,7 +125,8 @@ def __init__(self,
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size)
+            org_num_embeddings=self.language_model.org_vocab_size,
+            quant_config=quant_config)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
@@ -255,7 +256,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 8b078391b3497..f67598c4004b3 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -186,7 +186,8 @@ def __init__(self,
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.text_config.hidden_size,
-            org_num_embeddings=self.language_model.org_vocab_size)
+            org_num_embeddings=self.language_model.org_vocab_size,
+            quant_config=quant_config)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
@@ -438,7 +439,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 33020432713fb..4ccf1cf0fad76 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -449,6 +449,7 @@ def __init__(
             # We need bigger padding if using lora for kernel
            # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
         )
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
@@ -472,10 +473,10 @@ def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        logits = self.logits_processor(lm_head_weight, hidden_states,
+            lm_head = self.lm_head
+        logits = self.logits_processor(lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 5144e7ea4b803..7f5e3b9699c91 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -331,6 +331,7 @@ def __init__(
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -350,7 +351,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index dde2da20b3b98..10faa5cc6b6cc 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -344,7 +344,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -362,7 +364,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 290a703af6ffa..97f7ec74292bb 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -8,7 +8,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
@@ -87,7 +87,7 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
             self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                       (self.max_speculative_tokens - 1))
 
-        head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+        head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
         self.head = nn.ModuleList([head] * self.max_speculative_tokens)
 
         ln = MLPSpeculatorLayerNorm(self.inner_dim,
@@ -169,8 +169,8 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
 
-            logits = self.logits_processor(self.head[head_index].weight,
-                                           states, sampling_metadata)
+            logits = self.logits_processor(self.head[head_index], states,
+                                           sampling_metadata)
 
             output = self.sampler(logits.flatten(0, 1), sampling_metadata)
             last_tokens = output.sampled_token_ids
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 28dc5922cfe9c..7d658b39e6794 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -263,7 +263,7 @@ def __init__(
         self.quant_config = quant_config
 
         self.transformer = MPTModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.transformer.wte.weight
+        self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -281,7 +281,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 53215f32b92a3..408c0c883a9d0 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -283,15 +283,15 @@ def __init__(self,
         self.config = config
         self.model = OlmoModel(config, cache_config, quant_config)
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.unpadded_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -313,7 +313,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index d12a51af5a781..edc16710c0229 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -294,7 +294,7 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = OPTModel(config, cache_config, quant_config)
-        self.lm_head_weight = self.model.decoder.embed_tokens.weight
+        self.lm_head = self.model.decoder.embed_tokens
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -312,7 +312,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index a298f0307f3a0..8159cc13fba0b 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -259,7 +259,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = OrionModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -277,7 +279,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index cc8e31fe1adb9..ac7496f68fd99 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -268,7 +268,8 @@ def __init__(
 
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      bias=True)
+                                      bias=True,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -287,7 +288,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 
diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index 706ae65201d9f..cc06929fefab4 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -366,6 +366,7 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -400,7 +401,7 @@ def get_decoder(self):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index eff4e50294b3a..d73a42026bc32 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -365,7 +365,9 @@ def __init__(self,
         self.model = LlamaModel(config, cache_config, quant_config)
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             vlm_config, config, self.model.embed_tokens)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -409,7 +411,7 @@ def forward(self,
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 408c206c5e1ec..47c85c783db7a 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -235,7 +235,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.transformer = QWenModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -253,7 +255,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 3691a3d2e3614..e9ae2192f280d 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -316,11 +316,11 @@ def __init__(
         self.model = Qwen2Model(config, cache_config, quant_config)
 
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size)
-            self.lm_head_weight = self.lm_head.weight
+                                          config.hidden_size,
+                                          quant_config=quant_config)
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -339,7 +339,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 8decb4464fb36..ccaa6f20893e0 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -362,7 +362,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -380,7 +382,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 1098b3031b1e8..5451b56ed05f7 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -240,7 +240,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
         self.model = StableLMEpochModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -258,7 +260,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 6f3d5d51d0315..1752bfd473b88 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -242,7 +242,7 @@ def __init__(self,
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:
-            self.lm_head_weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
         else:
             self.unpadded_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
@@ -250,8 +250,8 @@ def __init__(self,
                 config.hidden_size,
                 org_num_embeddings=config.vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
             )
-            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -270,7 +270,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
 
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 08d3efd3312b9..84f0ffc376d65 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -310,7 +310,9 @@ def __init__(
         self.quant_config = quant_config
 
         self.model = XverseModel(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -328,7 +330,7 @@ def forward(
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
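
Note (not part of the patch): a minimal sketch of how the new lm_head quantization plumbing fits together, assuming a GPTQ-style quantization config dict; the bits/group_size values below are illustrative only.

    from vllm.model_executor.layers.quantization.gptq import GPTQConfig

    # "lm_head": true in the checkpoint's quantization config marks the output
    # projection as quantized; get_from_keys_or falls back to False when the
    # key is absent (illustrative values).
    cfg = GPTQConfig.from_config({
        "bits": 4,
        "group_size": 128,
        "desc_act": False,
        "lm_head": True,
    })
    assert cfg.lm_head_quantized

    # With the changes above, ParallelLMHead(..., quant_config=cfg) obtains a
    # GPTQLinearMethod from cfg.get_quant_method(layer), and
    # LogitsProcessor._get_logits computes logits through
    # lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias).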