[Model] Deepseek GGUF support (vllm-project#13167)
SzymonOzog authored Feb 27, 2025
1 parent edf309e commit 7f0be2a
Showing 8 changed files with 198 additions and 10 deletions.
7 changes: 7 additions & 0 deletions docs/source/features/quantization/gguf.md
@@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
We recommend using the tokenizer from the base model instead of the GGUF model, because tokenizer conversion from GGUF is time-consuming and unstable, especially for models with a large vocabulary.
:::

GGUF support assumes that Hugging Face can convert the GGUF metadata to a config file. If Hugging Face does not support your model, you can manually create a config and pass it as `hf-config-path`:

```console
# If your model is not supported by Hugging Face, you can manually provide a Hugging Face-compatible config path
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
```

You can also use the GGUF model directly through the LLM entrypoint:

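A minimal sketch of that usage, reusing the GGUF file and base-model tokenizer from the serve example above (the prompt and sampling settings are illustrative):

```python
from vllm import LLM, SamplingParams

# GGUF file and base-model tokenizer reused from the serve example above.
llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```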
9 changes: 6 additions & 3 deletions vllm/config.py
@@ -229,6 +229,7 @@ def __init__(
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
@@ -259,6 +260,7 @@ def __init__(
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
self.hf_config_path = hf_config_path
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
@@ -321,8 +323,9 @@ def __init__(
if self.enable_sleep_mode and not current_platform.is_cuda():
raise ValueError("Sleep mode is only supported on CUDA devices.")

hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format)
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)

if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -947,7 +950,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
def try_get_generation_config(self) -> Dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.model,
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
)
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
hf_config_path: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
@@ -262,6 +263,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=nullable_str,
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
@@ -1076,6 +1083,7 @@ def create_model_config(self) -> ModelConfig:

return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
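The new flag is threaded into `ModelConfig.hf_config_path` (see create_model_config above). A minimal offline sketch of the same override, assuming `LLM` forwards extra keyword arguments to `EngineArgs` as it does for the other engine arguments:

```python
from vllm import LLM

# Assumption: LLM(**kwargs) passes hf_config_path through EngineArgs to ModelConfig.
llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    hf_config_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)
```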
22 changes: 21 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -5,6 +5,7 @@
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -514,7 +515,12 @@ def weight_loader(self, param: torch.nn.Parameter,
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
@@ -524,6 +530,20 @@
if is_transposed:
shard_dim = int(not shard_dim)

full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[
shard_dim] // get_tensor_model_parallel_world_size()
param.materialize(final_shape, dtype=loaded_weight.dtype)

expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
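A standalone sketch of the materialization arithmetic above, assuming the GGUF MoE weight is not transposed (sizes are illustrative):

```python
# Mirrors the UninitializedParameter shape computation above (not vLLM API).
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

def gguf_moe_final_shape(loaded_shape, shard_id, tp_size):
    shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
    full_load = len(loaded_shape) == 3   # merged 3-D tensor carrying all experts
    if full_load:
        shard_dim += 1                   # skip the leading expert dimension
    final_shape = list(loaded_shape)
    if shard_id in ("w1", "w3"):
        final_shape[1] *= 2              # one buffer holds both gate and up projections
    final_shape[shard_dim] //= tp_size   # split the sharded dim across TP ranks
    return final_shape

# 8 experts, illustrative per-expert shape (8, 1024, 512):
print(gguf_moe_final_shape((8, 1024, 512), "w1", 1))  # [8, 2048, 512]
print(gguf_moe_final_shape((8, 1024, 512), "w1", 2))  # [8, 1024, 512]
print(gguf_moe_final_shape((8, 1024, 512), "w2", 2))  # [8, 1024, 256]
```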
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -235,10 +235,23 @@ def __init__(self,
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
# (such as scales for AutoFp8).
# Special case for GGUF

is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param.size() == loaded_weight.size()
assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}")
param.data.copy_(loaded_weight)

def forward(self,
127 changes: 125 additions & 2 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ def get_name(self) -> str:
return "gguf"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
return [torch.half]

@classmethod
def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@ def get_quant_method(self, layer: torch.nn.Module,
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return None


@@ -184,6 +189,124 @@ def apply(self,
return out


class GGUFMoEMethod(FusedMoEMethodBase):
"""MoE method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""

def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
hidden_size)
# gate up proj
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w13_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w13_qweight, extra_weight_attrs)
layer.register_parameter("w13_qweight", w13_qweight)

w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w13_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
layer.register_parameter("w13_qweight_type", w13_qweight_type)

tensor_shape = (num_experts, intermediate_size_per_partition,
hidden_size)
# down proj
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w2_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w2_qweight, extra_weight_attrs)
layer.register_parameter("w2_qweight", w2_qweight)

w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w2_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})

set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
final_hidden_states = torch.empty_like(x)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = layer.w13_qweight[ii]

out = _fuse_mul_mat(inp, expert_up,
layer.w13_qweight_type.weight_type)
out = self.act(out)

expert_down = layer.w2_qweight[ii]
current_state = _fuse_mul_mat(
out, expert_down,
layer.w2_qweight_type.weight_type).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
final_hidden_states[tok] = current_hidden_state
return final_hidden_states


class GGUFEmbeddingMethod(GGUFLinearMethod):
"""Embedding method for GGUF.
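For readability, a dense (unquantized) PyTorch sketch of the per-token routing loop in `GGUFMoEMethod.apply` above, with `_fuse_mul_mat` replaced by plain matrix multiplies and `SiluAndMul` by its standard definition; the weight layout here is illustrative, not the vLLM API:

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # silu(gate) * up on a tensor holding the gate and up halves concatenated.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def moe_forward_dense(x, w13, w2, topk_weights, topk_ids):
    """Dense reference of the expert loop.

    x:            (num_tokens, hidden)
    w13:          (num_experts, 2 * intermediate, hidden)  stacked gate/up proj
    w2:           (num_experts, hidden, intermediate)      down proj
    topk_weights: (num_tokens, top_k) routing weights
    topk_ids:     (num_tokens, top_k) selected expert indices
    """
    out = torch.empty_like(x)
    for tok, (weights, ids) in enumerate(zip(topk_weights, topk_ids)):
        inp = x[tok].unsqueeze(0)                # (1, hidden)
        acc = None
        for w, i in zip(weights, ids):
            h = silu_and_mul(inp @ w13[i].t())   # (1, intermediate)
            h = (h @ w2[i].t()).mul_(w)          # (1, hidden), scaled by routing weight
            acc = h if acc is None else acc.add_(h)
        out[tok] = acc.squeeze(0)
    return out
```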
19 changes: 17 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -1245,9 +1245,24 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
config = model_config.hf_config
model_type = config.model_type
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
model_type = "command-r"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# GGUF layer map assumes that we will have merged expert weights,
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"

arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
@@ -1258,10 +1273,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(config)
dummy_model = AutoModelForCausalLM.from_config(
config, trust_remote_code=model_config.trust_remote_code)
state_dict = dummy_model.state_dict()

gguf_to_hf_name_map = {}
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
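Concretely, for layer 0 the loop above adds the following entries, pointing each merged GGUF expert tensor at expert 0's Hugging Face name:

```python
# Mapping entries generated for idx == 0 by the loop above.
{
    "blk.0.exp_probs_b.bias": "model.layers.0.mlp.gate.e_score_correction_bias",
    "blk.0.ffn_down_exps.weight": "model.layers.0.mlp.experts.0.down_proj.weight",
    "blk.0.ffn_gate_exps.weight": "model.layers.0.mlp.experts.0.gate_proj.weight",
    "blk.0.ffn_up_exps.weight": "model.layers.0.mlp.experts.0.up_proj.weight",
}
```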
1 change: 0 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]

if weight_type.name != "F32":
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
