From 2cf6cf7fdaf6f8b790f22e4129aff754e0923a6b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 15 Jan 2025 07:49:49 +0800 Subject: [PATCH] [Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924) Signed-off-by: Jee Jee Li Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/model_loader/loader.py | 27 +++++++++------------- vllm/model_executor/model_loader/utils.py | 26 ++++++++++++++++++++- vllm/model_executor/models/baichuan.py | 7 ------ vllm/model_executor/models/exaone.py | 8 ------- vllm/model_executor/models/falcon.py | 6 ++--- vllm/model_executor/models/gemma.py | 9 -------- vllm/model_executor/models/gemma2.py | 10 -------- vllm/model_executor/models/granite.py | 8 ------- vllm/model_executor/models/idefics3.py | 10 -------- vllm/model_executor/models/llama.py | 10 -------- vllm/model_executor/models/llava.py | 12 ++++------ vllm/model_executor/models/minicpm.py | 10 -------- vllm/model_executor/models/minicpm3.py | 6 ----- vllm/model_executor/models/minicpmv.py | 20 ---------------- vllm/model_executor/models/mllama.py | 11 +++------ vllm/model_executor/models/molmo.py | 6 ----- vllm/model_executor/models/nemotron.py | 6 ----- vllm/model_executor/models/opt.py | 10 +++----- vllm/model_executor/models/phi.py | 8 ------- vllm/model_executor/models/phi3.py | 4 ---- vllm/model_executor/models/qwen.py | 7 ------ vllm/model_executor/models/qwen2.py | 10 -------- vllm/model_executor/models/qwen2_vl.py | 10 -------- vllm/model_executor/models/solar.py | 8 ------- 24 files changed, 49 insertions(+), 200 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 0033fbff0e9ac..9fe0db62435a0 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -39,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (get_model_architecture, +from vllm.model_executor.model_loader.utils import (ParamMapping, + get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -983,21 +984,11 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, def _get_bnb_target_modules(self, model: nn.Module) -> None: - # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with - # packed_modules_mapping. - inverse_stacked_mapping: Dict[str, List[str]] = {} - for orig, ( - packed, - idx, - ) in model.bitsandbytes_stacked_params_mapping.items(): - if packed not in inverse_stacked_mapping: - inverse_stacked_mapping[packed] = [] - inverse_stacked_mapping[packed].insert(idx, orig) - for name, module in model.named_modules(): if isinstance(module, (LinearBase, )): last_name = name.split(".")[-1] - if sub_modules := inverse_stacked_mapping.get(last_name, []): + if sub_modules := self.modules_mapping.packed_mapping.get( + last_name, []): # Map vllm's names to transformers's names. 
for sub_name in sub_modules: self.target_modules.append( @@ -1018,15 +1009,19 @@ def _load_weights(self, model_config: ModelConfig, "The required method 'load_weights' is not defined in class" f" {type(model).__name__}.") - if not hasattr(model, "bitsandbytes_stacked_params_mapping"): + if not hasattr(model, "packed_modules_mapping"): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet.") + "quantization yet. No 'packed_modules_mapping' found.") + + self.modules_mapping = ParamMapping( + copy.deepcopy(model.packed_modules_mapping)) # For some models like Molmo, we need to use hf_to_vllm_mapper # to ensure correct loading of weights. if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + # Modules whose weights might have fused on disk # we need their output_sizes to make shard in flight correctly with TP self.maybe_fused_weights_modules: Dict[str, List[int]] = {} @@ -1109,7 +1104,7 @@ def _load_weights(self, model_config: ModelConfig, for shard_name, ( weight_name, index, - ) in model.bitsandbytes_stacked_params_mapping.items(): + ) in self.modules_mapping.inverse_packed_mapping.items(): shard_pos = quant_param_name.find(shard_name) # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 44978a55e072d..3f923d2f6632a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,7 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Tuple, Type +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Type import torch from torch import nn @@ -49,3 +50,26 @@ def get_model_architecture( def get_architecture_class_name(model_config: ModelConfig) -> str: return get_model_architecture(model_config)[1] + + +@dataclass +class ParamMapping: + """ + A class to handle parameter mapping for model weight loading. + It creates a bidirectional mapping between packed parameters and their + constituent parts. 
+ """ + packed_mapping: Dict[str, List[str]] + inverse_packed_mapping: Dict[str, Tuple[str, + int]] = field(default_factory=dict) + + def __post_init__(self): + for packed_name, sub_params in self.packed_mapping.items(): + # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]}) + if len(sub_params) == 1 and sub_params[0] == packed_name: + continue + for index, param_name in enumerate(sub_params): + self.inverse_packed_mapping[param_name] = ( + packed_name, + index, + ) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5e68b7f165bf4..a923ed36a9db2 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def __init__( self, *, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 8324a563edd64..ad15f835b1609 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "c_fc_0": ("gate_up_proj", 0), - "c_fc_1": ("gate_up_proj", 1), - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 8660cf79b9cdb..c503a368e8244 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -409,9 +409,9 @@ def forward( class FalconForCausalLM(nn.Module, SupportsPP): - - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = {} + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index b28715c48adfb..6de0c866bc2f0 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "gate_up_proj", "down_proj", ] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } # Gemma does not apply LoRA to the embedding layer. 
embedding_modules = {} diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index f4530e4771960..698b9a5b6b1d6 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index a91ed4158a73f..3e95926fd1e22 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 43d2777d32a72..85ea6508445c7 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -820,16 +820,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, "down_proj", ] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 17b0fbb777e8e..16fa7acf54fdc 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - # Mistral/Llama models can also be loaded with --load-format mistral # from consolidated.safetensors checkpoints mistral_mapping = { diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index bb3db60c7d8ed..722fff98d5c19 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -463,14 +463,10 @@ def init_vision_tower_for_llava( info=_build_llava_or_pixtral_hf_info, dummy_inputs=LlavaDummyInputsBuilder) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": 
("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 5a0f202364f26..6254d26c7060d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index e9d7eada1d16c..5e1e6c6fa6141 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): # `embedding_modules` and `embedding_padding_modules` # are inherited from MiniCPMForCausalLM - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ff7dab89e4da8..1aa529056893b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): "kv_proj", ] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - embedding_modules = {} embedding_padding_modules = [] @@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): "kv_proj", ] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 593a4d3fb6940..b2368ffff5412 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1107,14 +1107,9 @@ def forward( @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + 
"gate_up_proj": ["gate_proj", "up_proj"] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index c45ee9b921c9e..a2fd1701316f2 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - "gate_proj": ("merged_linear", 0), - "up_proj": ("merged_linear", 1), - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 34cb9981c167b..8cc62d5c803cc 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 7edafcd20b5db..ea1185aa80dc6 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -329,13 +329,9 @@ def forward( class OPTForCausalLM(nn.Module, SupportsPP): - - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f9e972688ddd1..59b7508a370f8 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "fc2", ] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - } - embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index 937858ee3b8c2..34141511ea791 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM): "gate_up_proj", ], } - - # BitandBytes specific attributes - # Initialize an empty dict when there is no stacked parameter mapping. 
- bitsandbytes_stacked_params_mapping = {} diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index baf955f6b515d..1345b381f0a99 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1028,13 +1028,6 @@ class QWenLLM(QWenBaseModel): embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "w2": ("gate_up_proj", 0), - "w1": ("gate_up_proj", 1), - } - class QWenVL(QWenBaseModel, SupportsMultiModal): packed_modules_mapping = { diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index d20fb150f7e39..0a99c87470850 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -418,16 +418,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 76a810e8f0c20..d00e5d362c8bc 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1038,16 +1038,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index a7cf65a0e36e4..e83d316f74de2 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -401,14 +401,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__()
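
For reviewers, here is a minimal sketch (not part of the diff) of what the new `ParamMapping` dataclass from `model_loader/utils.py` does: it derives the `(packed_name, shard_index)` lookup that `bitsandbytes_stacked_params_mapping` used to hard-code per model, directly from each model's existing `packed_modules_mapping`. The dataclass below restates the one added in this patch, and the example mapping is the Llama/LLaVA-style mapping that appears in the hunks above; the assertions are illustrative only.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ParamMapping:
    """Bidirectional mapping between packed parameters and their parts
    (restates the dataclass added in model_loader/utils.py)."""
    packed_mapping: Dict[str, List[str]]
    inverse_packed_mapping: Dict[str, Tuple[str, int]] = field(
        default_factory=dict)

    def __post_init__(self):
        for packed_name, sub_params in self.packed_mapping.items():
            # Self-contained entries such as
            # {"query_key_value": ["query_key_value"]} have nothing to invert.
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (packed_name, index)


# Example: a Llama-style packed_modules_mapping now drives both the packed
# target-module discovery and the BitsAndBytes shard lookup.
mapping = ParamMapping({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})
assert mapping.inverse_packed_mapping["q_proj"] == ("qkv_proj", 0)
assert mapping.inverse_packed_mapping["up_proj"] == ("gate_up_proj", 1)
```

With this in place, models with no real packing (e.g. Falcon in the hunk above) declare a self-contained entry such as `{"query_key_value": ["query_key_value"]}` so the BitsAndBytes loader still finds a `packed_modules_mapping`, while the inversion step skips it and no shard splitting is attempted.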