@@ -624,6 +624,7 @@ def __init__(
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ def __init__(
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""
 
-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
+    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
+                 quant_config: Optional[QuantizationConfig]) \
         -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
+            quant_config=quant_config,
         )
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
@@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig,
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
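
For context, a minimal sketch of the pattern this change applies: a single quant_config is created once at the top of the model and forwarded through every constructor down to the projection layers. The classes below are simplified stand-ins, not vLLM's real ColumnParallelLinear/RowParallelLinear or Mllama modules, so the snippet runs without a distributed setup; only the propagation pattern mirrors the diff.

from typing import Optional


class QuantizationConfig:
    """Stand-in for vLLM's QuantizationConfig base class."""
    name = "example-quant"


class Linear:
    """Stand-in for ColumnParallelLinear / RowParallelLinear."""

    def __init__(self, in_features: int, out_features: int,
                 quant_config: Optional[QuantizationConfig] = None):
        # A real parallel linear layer would pick a quantized kernel here.
        self.quant_config = quant_config


class CrossAttention:
    """Stand-in for MllamaTextCrossAttention."""

    def __init__(self, hidden_size: int,
                 quant_config: Optional[QuantizationConfig] = None):
        # quant_config is forwarded to every projection, as in the diff above.
        self.qkv_proj = Linear(hidden_size, 3 * hidden_size, quant_config)
        self.o_proj = Linear(hidden_size, hidden_size, quant_config)


class DecoderLayer:
    """Stand-in for MllamaCrossAttentionDecoderLayer."""

    def __init__(self, hidden_size: int,
                 quant_config: Optional[QuantizationConfig] = None):
        self.cross_attn = CrossAttention(hidden_size, quant_config)


# One shared config is threaded from the top-level constructor down the stack.
layer = DecoderLayer(hidden_size=4096, quant_config=QuantizationConfig())
assert layer.cross_attn.qkv_proj.quant_config is layer.cross_attn.o_proj.quant_config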