Commit c84b2ae

NickLucche authored and mzusman committed
[BugFix] Fix GGUF tp>1 when vocab_size is not divisible by 64 (vllm-project#12230)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 84ebb34 commit c84b2ae

File tree

2 files changed: +12 -2 lines changed

tests/models/decoder_only/language/test_gguf.py

Lines changed: 10 additions & 0 deletions
@@ -66,12 +66,20 @@ def gguf_model(self):
     gguf_filename="starcoder2-3b.Q6_K.gguf",
 )
 
+DOLPHIN_CONFIG = GGUFTestConfig(
+    # Test VocabParallelEmbedding sharding issue.
+    original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
+    gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
+    gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
+)
+
 MODELS = [
     LLAMA_CONFIG,
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
     STABLELM_CONFIG,
+    DOLPHIN_CONFIG
     # STARCODER_CONFIG, # broken
 ]
 
@@ -107,6 +115,7 @@ def test_models(
 
     # Run unquantized model.
     with vllm_runner(model_name=model.original_model,
+                     enforce_eager=True,  # faster tests
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
@@ -115,6 +124,7 @@ def test_models(
 
     # Run gguf model.
     with vllm_runner(model_name=model.gguf_model,
+                     enforce_eager=True,
                      tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -355,7 +355,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         elif isinstance(param, UninitializedParameter):
             shape = list(loaded_weight.shape)
             if output_dim is not None:
-                shape[output_dim] = shape[output_dim] // self.tp_size
+                shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(tuple(shape), dtype=loaded_weight.dtype)
 
         # If parameter does not have output dim, then it should
@@ -381,7 +381,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size
 
-        # Copy the data.
+        # Copy the data. Select chunk corresponding to current shard.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         if current_platform.is_hpu():
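
Why the one-line change in weight_loader matters: the sharded embedding parameter used to be materialized with vocab_size // tp_size rows, while the slicing logic further down selects num_embeddings_per_partition rows per rank, a size derived from the vocabulary padded to a multiple of 64. When the vocabulary size is not divisible by 64, the two disagree and loading under tensor parallelism fails. Below is a minimal sketch of the mismatch; the padding helper and the concrete numbers (vocab_size=32002, tp_size=2) are illustrative assumptions, not code from the commit.

# Minimal sketch, assuming vLLM's default vocab padding of 64.
# The helper and numbers below are illustrative, not taken from the diff.

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    """Round the vocabulary size up to the next multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

vocab_size = 32002  # hypothetical vocab that is not divisible by 64
tp_size = 2

# Old behavior: materialize the shard with the raw per-rank size.
old_shard_rows = vocab_size // tp_size                                  # 16001

# What the rest of the loader expects: per-partition size of the padded vocab.
num_embeddings_per_partition = pad_vocab_size(vocab_size) // tp_size    # 16032

assert old_shard_rows != num_embeddings_per_partition  # shape mismatch on load

Sizing the materialized parameter with self.num_embeddings_per_partition, as this commit does, keeps it consistent with the narrow()-based shard selection regardless of whether the vocabulary divides evenly.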

0 commit comments
