
Commit c1ac591

[CORE] Quantized lm-head Framework (vllm-project#4442)
Committed by Qubitium, Robert Shaw, and ZX
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
1 parent 4d31612


48 files changed: +268, -121 lines

tests/lora/test_layers.py

Lines changed: 5 additions & 5 deletions

@@ -489,10 +489,10 @@ def _pretest():

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
-           embedding=linear.weight,
+           lm_head=linear,
            embedding_bias=None)

-       original_weight = linear.weight.clone()
+       original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
@@ -504,7 +504,7 @@ def _pretest():
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
-                                                  embedding=linear.weight,
+                                                  lm_head=linear,
                                                   embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
@@ -533,11 +533,11 @@ def _pretest():

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
-           embedding=original_weight,
+           lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
-           embedding=original_weight,
+           lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]

tests/quantization/test_lm_head.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+"""Tests whether gptq models with quantized lm_head can be loaded.
+
+Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
+"""
+from typing import Tuple
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinLinearMethod)
+from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+PROMPT = "On the surface of Mars, we found"
+
+MODELS_QUANT = [(
+    "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
+    True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
+            ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
+
+
+@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
+def test_lm_head(
+    vllm_runner,
+    model_lm_head_quant: Tuple[str, bool],
+) -> None:
+    model, lm_head_quantized = model_lm_head_quant
+    vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
+
+    lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model.lm_head)
+
+    if lm_head_quantized:
+        assert isinstance(
+            lm_head_layer.linear_method,
+            (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
+    else:
+        assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
+
+    print(
+        vllm_model.generate_greedy(prompts=["Hello my name is"],
+                                   max_tokens=10)[0][1])
+    del vllm_model
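
As a quick sanity check outside the test harness, a checkpoint with a quantized lm_head loads through the normal offline API like any other quantized model. A minimal sketch, reusing the first model ID from MODELS_QUANT above (the prompt and sampling settings here are arbitrary):

    from vllm import LLM, SamplingParams

    # Checkpoint whose quantization config marks lm_head as quantized.
    llm = LLM(model="LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
              dtype="float16",
              max_model_len=2048)
    outputs = llm.generate(["On the surface of Mars, we found"],
                           SamplingParams(temperature=0.0, max_tokens=10))
    print(outputs[0].outputs[0].text)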

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@
 MAX_SPEC_TOKENS = 5

 # precision
-PRECISION = "float16"
+PRECISION = "float32"


 @pytest.mark.parametrize(

tests/test_logits_processor.py

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ def pick_ith(token_ids, logits):
        device=device,
        pin_memory=is_pin_memory_available())
    logits_processor_output = logits_processor(
-       embedding=None,
+       lm_head=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

vllm/lora/layers.py

Lines changed: 2 additions & 2 deletions

@@ -1172,11 +1172,11 @@ def set_mapping(
    def _get_logits(
        self,
        hidden_states: torch.Tensor,
-       embedding: torch.Tensor,
+       lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
-       logits = torch.matmul(hidden_states, embedding.t())
+       logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
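
The LoRA logits processor now defers to the lm_head layer's linear_method rather than hard-coding a matmul, so a quantized lm_head is handled the same way as any quantized linear layer. For the unquantized case the two paths are equivalent; a hand-written sketch for illustration (not the vllm implementation of UnquantizedLinearMethod.apply):

    import torch
    import torch.nn.functional as F

    def unquantized_apply(layer: torch.nn.Module,
                          x: torch.Tensor,
                          bias: torch.Tensor = None) -> torch.Tensor:
        # Same result as the old `torch.matmul(hidden_states, embedding.t())`
        # followed by an optional bias add.
        return F.linear(x, layer.weight, bias)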

vllm/model_executor/layers/logits_processor.py

Lines changed: 9 additions & 7 deletions

@@ -6,6 +6,8 @@
 import torch.nn as nn

 from vllm.distributed import tensor_model_parallel_gather
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata


@@ -40,7 +42,7 @@ def __init__(self,

    def forward(
        self,
-       embedding: torch.Tensor,
+       lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
@@ -52,8 +54,7 @@ def forward(
                                                 sampling_metadata)

            # Get the logits for the next tokens.
-           logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
+           logits = self._get_logits(hidden_states, lm_head, embedding_bias)
        if logits is not None:
            if self.soft_cap is not None:
                logits = logits / self.soft_cap
@@ -68,12 +69,13 @@ def forward(

        return logits

-   def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+   def _get_logits(self, hidden_states: torch.Tensor,
+                   lm_head: VocabParallelEmbedding,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
-       logits = torch.matmul(hidden_states, embedding.t())
-       if embedding_bias is not None:
-           logits += embedding_bias
+       logits = lm_head.linear_method.apply(lm_head,
+                                            hidden_states,
+                                            bias=embedding_bias)
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
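
Most of the files not shown here are model definitions whose compute_logits call sites change to match: they now hand the logits processor the ParallelLMHead module itself instead of its weight tensor. Roughly, in a typical model class (shown only as an illustration of the call-site change, not an exact excerpt):

    # Before: pass the raw weight tensor.
    logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                   sampling_metadata)

    # After: pass the lm_head module, so its linear_method can be applied.
    logits = self.logits_processor(self.lm_head, hidden_states,
                                   sampling_metadata)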

vllm/model_executor/layers/quantization/base_config.py

Lines changed: 9 additions & 0 deletions

@@ -87,6 +87,15 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

+   @staticmethod
+   def get_from_keys_or(config: Dict[str, Any], keys: List[str],
+                        default: Any) -> Any:
+       """Get a optional value from the model's quantization config."""
+       try:
+           return QuantizationConfig.get_from_keys(config, keys)
+       except ValueError:
+           return default
+
    @abstractmethod
    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
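
get_from_keys_or is what keeps older quantization configs without an lm_head entry working. A minimal usage sketch with a hypothetical config dict:

    config = {"bits": 4, "group_size": 128, "desc_act": False}  # no "lm_head" key

    # Falls back to the default instead of raising ValueError.
    lm_head_quantized = QuantizationConfig.get_from_keys_or(config, ["lm_head"],
                                                            default=False)
    assert lm_head_quantized is False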

vllm/model_executor/layers/quantization/gptq.py

Lines changed: 10 additions & 3 deletions

@@ -10,6 +10,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs


@@ -24,10 +25,12 @@ def __init__(
        weight_bits: int,
        group_size: int,
        desc_act: bool,
+       lm_head_quantized: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
+       self.lm_head_quantized = lm_head_quantized
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
@@ -37,7 +40,8 @@ def __init__(
    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
-               f"desc_act={self.desc_act})")
+               f"desc_act={self.desc_act}),"
+               f"lm_head_quantized={self.lm_head_quantized}")

    @classmethod
    def get_name(cls) -> str:
@@ -61,11 +65,14 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
-       return cls(weight_bits, group_size, desc_act)
+       lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                default=False)
+       return cls(weight_bits, group_size, desc_act, lm_head_quantized)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
-       if isinstance(layer, LinearBase):
+       if (isinstance(layer, LinearBase) or
+               (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return GPTQLinearMethod(self)
        return None
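
Putting the GPTQ changes together: a checkpoint whose quantization_config contains "lm_head": true now routes the ParallelLMHead through GPTQLinearMethod as well. A short sketch with a hand-written config dict in the shape GPTQ checkpoints use:

    from vllm.model_executor.layers.quantization.gptq import GPTQConfig

    quant_cfg = {
        "bits": 4,
        "group_size": 128,
        "desc_act": False,
        "lm_head": True,  # written by quantizers that also quantize lm_head
    }
    gptq_config = GPTQConfig.from_config(quant_cfg)
    assert gptq_config.lm_head_quantized
    # get_quant_method() now returns GPTQLinearMethod for LinearBase layers and,
    # when lm_head_quantized is True, for a ParallelLMHead as well.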

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 11 additions & 4 deletions

@@ -11,6 +11,7 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.utils import get_device_capability_stateless

 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                is_sym: bool) -> None:
+                is_sym: bool, lm_head_quantized: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
@@ -69,6 +70,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
+       self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@@ -96,7 +98,8 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
-               f"desc_act={self.desc_act})")
+               f"desc_act={self.desc_act}, "
+               f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
@@ -120,7 +123,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
-       return cls(weight_bits, group_size, desc_act, is_sym)
+       lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                default=False)
+       return cls(weight_bits, group_size, desc_act, is_sym,
+                  lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
@@ -145,7 +151,8 @@ def override_quantization_method(cls, hf_quant_cfg,
    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
-       if isinstance(layer, LinearBase):
+       if (isinstance(layer, LinearBase) or
+               (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return GPTQMarlinLinearMethod(self)
        return None

vllm/model_executor/layers/quantization/marlin.py

Lines changed: 10 additions & 3 deletions

@@ -8,6 +8,7 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs

 logger = init_logger(__name__)
@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
    def __init__(
        self,
        group_size: int,
+       lm_head_quantized: bool,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
+       self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
@@ -51,7 +54,8 @@ def __init__(
        self.perm_len = 1024

    def __repr__(self) -> str:
-       return f"MarlinConfig(group_size={self.group_size})"
+       return (f"MarlinConfig(group_size={self.group_size}, "
+               f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
@@ -73,7 +77,9 @@ def get_config_filenames(cls) -> List[str]:
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
-       return cls(group_size)
+       lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                default=False)
+       return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
@@ -96,7 +102,8 @@ def override_quantization_method(cls, hf_quant_cfg,

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
-       if isinstance(layer, LinearBase):
+       if (isinstance(layer, LinearBase) or
+               (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return MarlinLinearMethod(self)
        return None