[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (vllm-project#11924)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: hzh <hezhihui_thu@163.com>
jeejeelee authored and HwwwwwwwH committed Jan 22, 2025
1 parent cc9cde5 commit 58d45cd
Showing 24 changed files with 49 additions and 200 deletions.
27 changes: 11 additions & 16 deletions vllm/model_executor/model_loader/loader.py
@@ -39,7 +39,8 @@
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
from vllm.model_executor.model_loader.utils import (ParamMapping,
get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -983,21 +984,11 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

def _get_bnb_target_modules(self, model: nn.Module) -> None:

# TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
# packed_modules_mapping.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (
packed,
idx,
) in model.bitsandbytes_stacked_params_mapping.items():
if packed not in inverse_stacked_mapping:
inverse_stacked_mapping[packed] = []
inverse_stacked_mapping[packed].insert(idx, orig)

for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
if sub_modules := self.modules_mapping.packed_mapping.get(
last_name, []):
# Map vllm's names to transformers's names.
for sub_name in sub_modules:
self.target_modules.append(
@@ -1018,15 +1009,19 @@ def _load_weights(self, model_config: ModelConfig,
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")

if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
"quantization yet. No 'packed_modules_mapping' found.")

self.modules_mapping = ParamMapping(
copy.deepcopy(model.packed_modules_mapping))

# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
@@ -1109,7 +1104,7 @@ def _load_weights(self, model_config: ModelConfig,
for shard_name, (
weight_name,
index,
) in model.bitsandbytes_stacked_params_mapping.items():
) in self.modules_mapping.inverse_packed_mapping.items():
shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
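
Editor's note on the new lookup path (a hedged sketch, not part of the diff): the loader now resolves quantized parameter names through the inverse_packed_mapping of a ParamMapping built from the model's packed_modules_mapping, rather than through a hand-written bitsandbytes_stacked_params_mapping. The minimal Python sketch below uses an illustrative mapping and one possible boundary check in the spirit of the 'kv_proj'/'qkv_proj' comment above; the names and the exact check are assumptions, not copied from the loader.

# Illustrative inverse mapping, as ParamMapping would derive it from
# {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}.
inverse_packed_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
}

def find_shard(quant_param_name: str):
    """Return (packed weight name, shard index) for a shard parameter,
    or None if the name does not belong to any packed module."""
    for shard_name, (weight_name, index) in inverse_packed_mapping.items():
        shard_pos = quant_param_name.find(shard_name)
        # Accept a match only when it starts right after a '.', so a short
        # shard name cannot match inside a longer module name (the
        # 'kv_proj' vs 'qkv_proj' situation mentioned above).
        if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
            return weight_name, index
    return None

print(find_shard("model.layers.0.self_attn.k_proj.weight"))  # ('qkv_proj', 1)
print(find_shard("resampler.kv_proj.weight"))                 # None
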
26 changes: 25 additions & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -1,6 +1,7 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Tuple, Type
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type

import torch
from torch import nn
@@ -49,3 +50,26 @@ def get_model_architecture(

def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]


@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: Dict[str, List[str]]
inverse_packed_mapping: Dict[str, Tuple[str,
int]] = field(default_factory=dict)

def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
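
A quick illustration of the ParamMapping helper added above (editor's example; the mapping values are typical ones and not tied to a specific model): constructing it from a packed_modules_mapping gives both lookup directions, and self-contained entries are left out of the inverse map.

from vllm.model_executor.model_loader.utils import ParamMapping

mapping = ParamMapping({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
    "W_pack": ["W_pack"],  # self-contained entry, skipped in the inverse map
})

print(mapping.packed_mapping["qkv_proj"])          # ['q_proj', 'k_proj', 'v_proj']
print(mapping.inverse_packed_mapping["up_proj"])   # ('gate_up_proj', 1)
print("W_pack" in mapping.inverse_packed_mapping)  # False
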
7 changes: 0 additions & 7 deletions vllm/model_executor/models/baichuan.py
@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
*,
8 changes: 0 additions & 8 deletions vllm/model_executor/models/exaone.py
@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"c_fc_0": ("gate_up_proj", 0),
"c_fc_1": ("gate_up_proj", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
6 changes: 3 additions & 3 deletions vllm/model_executor/models/falcon.py
@@ -409,9 +409,9 @@ def forward(


class FalconForCausalLM(nn.Module, SupportsPP):

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
9 changes: 0 additions & 9 deletions vllm/model_executor/models/gemma.py
@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj",
"down_proj",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
10 changes: 0 additions & 10 deletions vllm/model_executor/models/gemma2.py
@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
8 changes: 0 additions & 8 deletions vllm/model_executor/models/granite.py
@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
10 changes: 0 additions & 10 deletions vllm/model_executor/models/idefics3.py
@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
"down_proj",
]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

embedding_modules = {}
embedding_padding_modules = []

10 changes: 0 additions & 10 deletions vllm/model_executor/models/llama.py
@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
12 changes: 4 additions & 8 deletions vllm/model_executor/models/llava.py
@@ -463,14 +463,10 @@ def init_vision_tower_for_llava(
info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
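
The per-model edits below all follow the same pattern as this one: the hand-written bitsandbytes_stacked_params_mapping is removed and the loader derives the same information from packed_modules_mapping. A small sanity-check sketch of that equivalence (editor's example, using the values visible in this diff):

from vllm.model_executor.model_loader.utils import ParamMapping

# Mapping deleted from LlavaForConditionalGeneration above.
old_bnb_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

# Mapping added in its place; ParamMapping reconstructs the old form.
derived = ParamMapping({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})

assert derived.inverse_packed_mapping == old_bnb_mapping
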
10 changes: 0 additions & 10 deletions vllm/model_executor/models/minicpm.py
@@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
6 changes: 0 additions & 6 deletions vllm/model_executor/models/minicpm3.py
@@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
20 changes: 0 additions & 20 deletions vllm/model_executor/models/minicpmv.py
@@ -1165,16 +1165,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj",
]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

embedding_modules = {}
embedding_padding_modules = []

@@ -1285,16 +1275,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
"kv_proj",
]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

embedding_modules = {}
embedding_padding_modules = []

11 changes: 3 additions & 8 deletions vllm/model_executor/models/mllama.py
@@ -1107,14 +1107,9 @@ def forward(
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
6 changes: 0 additions & 6 deletions vllm/model_executor/models/molmo.py
@@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
embedding_modules = {}
embedding_padding_modules = []

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
"gate_proj": ("merged_linear", 0),
"up_proj": ("merged_linear", 1),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
6 changes: 0 additions & 6 deletions vllm/model_executor/models/nemotron.py
@@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
10 changes: 3 additions & 7 deletions vllm/model_executor/models/opt.py
@@ -329,13 +329,9 @@ def forward(


class OPTForCausalLM(nn.Module, SupportsPP):

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 0 additions & 8 deletions vllm/model_executor/models/phi.py
@@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc2",
]

# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}

embedding_modules = {}
embedding_padding_modules = []

4 changes: 0 additions & 4 deletions vllm/model_executor/models/phi3.py
@@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM):
"gate_up_proj",
],
}

# BitandBytes specific attributes
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}