
Commit 3eed01d

Isotr0py authored and DarkLight1337 committed
[Model] Support GGUF models newly added in transformers 4.46.0 (vllm-project#9685)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 55ab59d commit 3eed01d

File tree

7 files changed: +162 -87 lines changed


examples/offline_inference/gguf_inference.py

Lines changed: 8 additions & 14 deletions
@@ -3,27 +3,20 @@
 from vllm import LLM, SamplingParams
 
 
-def run_gguf_inference(model_path):
-    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
-    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
+def run_gguf_inference(model_path, tokenizer):
     # Sample prompts.
     prompts = [
         "How many helicopters can a human eat in one sitting?",
         "What's the future of AI?",
     ]
-    prompts = [
-        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
-        for prompt in prompts
-    ]
+    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
     # Create a sampling params object.
     sampling_params = SamplingParams(temperature=0, max_tokens=128)
 
     # Create an LLM.
-    llm = LLM(model=model_path,
-              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              gpu_memory_utilization=0.95)
+    llm = LLM(model=model_path, tokenizer=tokenizer)
 
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.chat(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
         prompt = output.prompt
@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):
 
 
 if __name__ == "__main__":
-    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
+    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
+    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
+    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
     model = hf_hub_download(repo_id, filename=filename)
-    run_gguf_inference(model)
+    run_gguf_inference(model, tokenizer)
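
For reference, the example script after this change reads roughly as follows. This is a sketch reconstructed from the hunks above; the top-of-file hf_hub_download import and the body of the print loop are not shown in the diff and are assumed.

from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams


def run_gguf_inference(model_path, tokenizer):
    # Sample prompts, passed as chat conversations rather than raw strings.
    prompts = [
        "How many helicopters can a human eat in one sitting?",
        "What's the future of AI?",
    ]
    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0, max_tokens=128)

    # Create an LLM; the tokenizer comes from the original (unquantized) repo
    # so llm.chat() can apply its chat template to the conversations above.
    llm = LLM(model=model_path, tokenizer=tokenizer)

    outputs = llm.chat(prompts, sampling_params)
    # Print the outputs (loop body assumed from the unchanged part of the file).
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
    model = hf_hub_download(repo_id, filename=filename)
    run_gguf_inference(model, tokenizer)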

tests/models/decoder_only/language/test_gguf.py

Lines changed: 74 additions & 31 deletions
@@ -4,45 +4,90 @@
 """
 
 import os
+from typing import List, NamedTuple, Type
 
 import pytest
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ....conftest import VllmRunner
 from ...utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 MAX_MODEL_LEN = 1024
 
 
+class GGUFTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_filename: str
+
+    @property
+    def gguf_model(self):
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
+
+
+LLAMA_CONFIG = GGUFTestConfig(
+    original_model="meta-llama/Llama-3.2-1B-Instruct",
+    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
+    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+)
+
+QWEN2_CONFIG = GGUFTestConfig(
+    original_model="Qwen/Qwen2.5-1.5B-Instruct",
+    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
+)
+
+PHI3_CONFIG = GGUFTestConfig(
+    original_model="microsoft/Phi-3.5-mini-instruct",
+    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
+    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+)
+
+GPT2_CONFIG = GGUFTestConfig(
+    original_model="openai-community/gpt2-large",
+    gguf_repo="QuantFactory/gpt2-large-GGUF",
+    gguf_filename="gpt2-large.Q4_K_M.gguf",
+)
+
+STABLELM_CONFIG = GGUFTestConfig(
+    original_model="stabilityai/stablelm-3b-4e1t",
+    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+)
+
+STARCODER_CONFIG = GGUFTestConfig(
+    original_model="bigcode/starcoder2-3b",
+    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
+    gguf_filename="starcoder2-3b.Q6_K.gguf",
+)
+
+MODELS = [
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    STABLELM_CONFIG,
+    # STARCODER_CONFIG, # broken
+]
+
+
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
-     "qwen2-1_5b-instruct-q4_k_m.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
-])
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
-    num_gpus_available,
-    vllm_runner,
-    example_prompts,
-    original_model,
-    gguf_id,
-    gguf_path,
+    num_gpus_available: int,
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
    num_logprobs: int,
@@ -51,28 +96,26 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
-    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
-
-    tokenizer = AutoTokenizer.from_pretrained(original_model)
-    messages = [[{
-        'role': 'user',
-        'content': prompt
-    }] for prompt in example_prompts]
-    example_prompts = tokenizer.apply_chat_template(messages,
-                                                    tokenize=False,
-                                                    add_generation_prompt=True)
+    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
+    if tokenizer.chat_template is not None:
+        messages = [[{
+            'role': 'user',
+            'content': prompt
+        }] for prompt in example_prompts]
+        example_prompts = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
 
     # Run unquantized model.
-    with vllm_runner(model_name=original_model,
+    with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
-
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
     # Run gguf model.
-    with vllm_runner(model_name=gguf_model,
+    with vllm_runner(model_name=model.gguf_model,
+                     tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
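
A note on the new test scaffolding: gguf_model is a property, so each GGUFTestConfig is cheap to declare at import time and the .gguf file is only downloaded (or taken from the local Hugging Face cache) when a test actually touches it. A minimal standalone sketch of that behaviour, reusing the LLAMA_CONFIG values from the diff (this is not part of the test suite):

from typing import NamedTuple

from huggingface_hub import hf_hub_download


class GGUFTestConfig(NamedTuple):
    original_model: str
    gguf_repo: str
    gguf_filename: str

    @property
    def gguf_model(self) -> str:
        # Resolved lazily: downloads on first access, then reuses the
        # local huggingface_hub cache on subsequent accesses.
        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)


config = GGUFTestConfig(
    original_model="meta-llama/Llama-3.2-1B-Instruct",
    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
)
print(config.gguf_model)  # local path to the cached .gguf file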

vllm/model_executor/layers/linear.py

Lines changed: 35 additions & 23 deletions
@@ -447,8 +447,14 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
-            param.data[loaded_shard_id].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[loaded_shard_id].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return
 
         if is_gguf_weight:
@@ -459,15 +465,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 2:
+                    self.qweight = param.materialize_nested()
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -811,10 +817,16 @@ def weight_loader(self,
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
-        if is_gguf_weight_type and loaded_shard_id is not None:
+        if is_gguf_weight_type:
             idx_map = {"q": 0, "k": 1, "v": 2}
-            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return
 
         if is_gguf_weight:
@@ -825,15 +837,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 3:
+                    self.qweight = param.materialize_nested()
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
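
The new `loaded_shard_id is not None` guards cover GGUF checkpoints that ship a merged projection (e.g. a fused gate/up or qkv tensor) as a single tensor: the loader is then called once with no shard id, and every logical shard shares one weight-type code. A small standalone illustration of the new else-branches, with names mirroring the diff and made-up values:

import torch

# Pretend GGUF metadata: one weight-type code covering a fused tensor.
loaded_weight = torch.tensor(12)     # e.g. a GGML quantization type id
output_sizes = [2048, 2048]          # widths of the merged column shards
idx_map = {"q": 0, "k": 1, "v": 2}   # QKVParallelLinear's shard ids

# MergedColumnParallelLinear path when loaded_shard_id is None:
shard_weight_type = {i: loaded_weight.item() for i, _ in enumerate(output_sizes)}
print(shard_weight_type)             # {0: 12, 1: 12}

# QKVParallelLinear path when loaded_shard_id is None:
shard_weight_type = {k: loaded_weight.item() for k in idx_map}
print(shard_weight_type)             # {'q': 12, 'k': 12, 'v': 12}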

vllm/model_executor/models/gpt2.py

Lines changed: 8 additions & 3 deletions
@@ -198,7 +198,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.wte")
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +262,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.lm_head")
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
@@ -304,7 +309,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if name.startswith("lm_head"):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
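
The broader lm_head filter matters because a quantized checkpoint can expose the tied head under parameter names other than a plain weight. A quick illustration of the old versus new skip condition; the specific tensor names here are hypothetical and only the comparison logic mirrors the diff:

# Hypothetical checkpoint tensor names.
names = [
    "lm_head.weight",
    "lm_head.qweight",        # quantized data (name assumed)
    "lm_head.qweight_type",   # quantization type (name assumed)
    "transformer.wte.weight",
]

old_skip = [n for n in names if "lm_head.weight" in n]
new_skip = [n for n in names if n.startswith("lm_head")]

print(old_skip)  # ['lm_head.weight']
print(new_skip)  # ['lm_head.weight', 'lm_head.qweight', 'lm_head.qweight_type']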

vllm/model_executor/models/llama.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@ def __init__(
         )
 
         is_neox_style = True
-        if quant_config is not None and quant_config.get_name() == "gguf":
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
             is_neox_style = False
 
         self.rotary_emb = get_rope(
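
Condensed, the new guard keeps the non-NeoX rotary path specific to GGUF checkpoints whose model_type is literally "llama", instead of applying it to every architecture that reuses this attention module. A standalone restatement of the boolean logic; quant_config here stands in for the real vLLM object:

def resolve_is_neox_style(quant_config, model_type: str) -> bool:
    # Mirrors the diff: only GGUF combined with model_type == "llama"
    # flips the default NeoX-style rotary embedding.
    is_gguf = quant_config and quant_config.get_name() == "gguf"
    return not (is_gguf and model_type == "llama")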
