[Misc] Support quantization of MllamaForCausalLM (vllm-project#8822)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
mgoin authored and garg-amit committed Oct 28, 2024
1 parent cf3a594 commit 9ad56f5
Showing 1 changed file with 9 additions and 2 deletions.
vllm/model_executor/models/mllama.py (9 additions, 2 deletions)

@@ -624,6 +624,7 @@ def __init__(
         self,
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -648,12 +649,14 @@ def __init__(
             self.num_heads,
             self.num_key_value_heads,
             bias=False,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -708,13 +711,15 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""

-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
+    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
+                 quant_config: Optional[QuantizationConfig]) \
             -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
+            quant_config=quant_config,
         )

         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -725,6 +730,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                  eps=config.rms_norm_eps)
@@ -780,7 +786,8 @@ def __init__(self, config: config_mllama.MllamaTextConfig,
         for layer_idx in range(config.num_hidden_layers):
             if layer_idx in self.cross_attention_layers:
                 layers.append(
-                    MllamaCrossAttentionDecoderLayer(config, layer_idx))
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_idx, quant_config=quant_config))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(

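The change threads quant_config from the Mllama text model down into MllamaCrossAttentionDecoderLayer and MllamaTextCrossAttention, so the cross-attention QKV and output projections and the cross-attention block's MLP are constructed through vLLM's quantization-aware parallel linear layers instead of always running in full precision. Below is a minimal usage sketch (not part of the commit) for serving a quantized Llama 3.2 Vision checkpoint after this change; the checkpoint name is a placeholder assumption, and any Mllama checkpoint with a quantization config that vLLM recognizes should behave the same way.

# Usage sketch only; the model name below is a placeholder assumption,
# not a reference to a specific published checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(
    model="my-org/Llama-3.2-11B-Vision-Instruct-FP8",  # placeholder quantized Mllama checkpoint
    max_model_len=4096,
    max_num_seqs=2,
    enforce_eager=True,
)

# Text-only prompt for brevity; Mllama also accepts image inputs through
# vLLM's multi-modal prompt format.
outputs = llm.generate(
    "Explain weight quantization in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)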