Commit 7193774

[Misc] Support quantization of MllamaForCausalLM (vllm-project#8822)
1 parent e2c6e0a · commit 7193774

File tree: 1 file changed (+9, −2 lines)

vllm/model_executor/models/mllama.py

+9 −2
@@ -624,6 +624,7 @@ def __init__(
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ def __init__(
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""
 
-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
+    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
+                 quant_config: Optional[QuantizationConfig]) \
             -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
+            quant_config=quant_config,
         )
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
@@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig,
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
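
Taken together, the hunks thread quant_config from the text-model constructor down into the QKVParallelLinear and RowParallelLinear projections and the MLP of the cross-attention decoder layers, so a quantized Mllama (Llama 3.2 Vision) checkpoint is no longer limited to the standard Llama decoder layers honoring the quantization config. A minimal usage sketch with vLLM's offline API follows; the checkpoint name is a placeholder, not a real model, and the quantization method is normally auto-detected from the checkpoint's own quantization config:

# Sketch only: "example-org/Llama-3.2-11B-Vision-Instruct-FP8" is a hypothetical
# placeholder; substitute any real quantized Mllama checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(
    model="example-org/Llama-3.2-11B-Vision-Instruct-FP8",
    max_model_len=4096,
    # quantization="fp8" can also be passed explicitly, but vLLM usually infers
    # the method from the checkpoint's quantization_config.
)

outputs = llm.generate(
    ["Describe what weight quantization does in one sentence."],
    SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)

The sketch assumes the quantized weights ship in the checkpoint; this commit's role is only to make sure the cross-attention blocks of MllamaForCausalLM build quantized linear layers from that config instead of silently falling back to unquantized weights.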
