Skip to content

Commit

Permalink
[CORE] Quantized lm-head Framework (#4442)
Browse files Browse the repository at this point in the history
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
  • Loading branch information
3 people authored Jul 2, 2024
1 parent 7c008c5 commit ee93f4f
Show file tree
Hide file tree
Showing 48 changed files with 268 additions and 121 deletions.
10 changes: 5 additions & 5 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,10 @@ def _pretest():

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)

original_weight = linear.weight.clone()
original_lm_head = deepcopy(linear)

linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
Expand All @@ -490,7 +490,7 @@ def _pretest():
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
Expand Down Expand Up @@ -519,11 +519,11 @@ def _pretest():

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)

rtol, atol = TOLERANCES[lora_result.dtype]
Expand Down
45 changes: 45 additions & 0 deletions tests/quantization/test_lm_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests whether gptq models with quantized lm_head can be loaded.
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from typing import Tuple

import pytest
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod

PROMPT = "On the surface of Mars, we found"

MODELS_QUANT = [(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]


@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
def test_lm_head(
vllm_runner,
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)

lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)

if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)

print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
MAX_SPEC_TOKENS = 5

# precision
PRECISION = "float16"
PRECISION = "float32"


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def pick_ith(token_ids, logits):
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

Expand Down
4 changes: 2 additions & 2 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,11 +1172,11 @@ def set_mapping(
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch.nn as nn

from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata


Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self,

def forward(
self,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
Expand All @@ -52,8 +54,7 @@ def forward(
sampling_metadata)

# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)

logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
Expand All @@ -68,12 +69,13 @@ def forward(

return logits

def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
def _get_logits(self, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
Expand Down
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")

@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:
return QuantizationConfig.get_from_keys(config, keys)
except ValueError:
return default

@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs


Expand All @@ -24,10 +25,12 @@ def __init__(
weight_bits: int,
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
Expand All @@ -37,7 +40,8 @@ def __init__(
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")

@classmethod
def get_name(cls) -> str:
Expand All @@ -61,11 +65,14 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
return None

Expand Down
15 changes: 11 additions & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.utils import get_device_capability_stateless

logger = init_logger(__name__)
Expand Down Expand Up @@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""

def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
Expand All @@ -69,6 +70,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized

# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
Expand Down Expand Up @@ -96,7 +98,8 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
Expand All @@ -120,7 +123,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
Expand All @@ -145,7 +151,8 @@ def override_quantization_method(cls, hf_quant_cfg,
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self)
return None

Expand Down
13 changes: 10 additions & 3 deletions vllm/model_executor/layers/quantization/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)
Expand All @@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
Expand All @@ -51,7 +54,8 @@ def __init__(
self.perm_len = 1024

def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size})"
return (f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
Expand All @@ -73,7 +77,9 @@ def get_config_filenames(cls) -> List[str]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(group_size, lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
Expand All @@ -96,7 +102,8 @@ def override_quantization_method(cls, hf_quant_cfg,

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self)
return None

Expand Down
Loading

0 comments on commit ee93f4f

Please sign in to comment.