From 65cdc0ddeeb22354d1786f842daca55713f609ad Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 25 Sep 2024 17:46:22 -0400 Subject: [PATCH] [Misc] Support quantization of MllamaForCausalLM (#8822) Signed-off-by: Alvant --- vllm/model_executor/models/mllama.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index aa868a3b8da28..45d6ad3c0efa5 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -624,6 +624,7 @@ def __init__( self, config: Optional[config_mllama.MllamaTextConfig] = None, layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -648,12 +649,14 @@ def __init__( self.num_heads, self.num_key_value_heads, bias=False, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, self.hidden_size, bias=False, input_is_parallel=True, + quant_config=quant_config, ) # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # use huggingface's instead @@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" - def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, + quant_config: Optional[QuantizationConfig]) \ -> None: super().__init__() self.layer_idx = layer_idx self.cross_attn = MllamaTextCrossAttention( config=config, layer_idx=layer_idx, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, @@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + quant_config=quant_config, ) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig, for layer_idx in range(config.num_hidden_layers): if layer_idx in self.cross_attention_layers: layers.append( - MllamaCrossAttentionDecoderLayer(config, layer_idx)) + MllamaCrossAttentionDecoderLayer( + config, layer_idx, quant_config=quant_config)) else: # TODO: force LlamaDecoderLayer to config.attention_bias=False layers.append(