From 07c4353057c2abb65db87104d44973c8cf5833bf Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 25 Feb 2025 20:07:12 -0500 Subject: [PATCH] [Model] Support Grok1 (#13795) Signed-off-by: mgoin --- docs/source/models/supported_models.md | 5 + tests/models/registry.py | 2 + .../layers/fused_moe/fused_moe.py | 43 +- vllm/model_executor/layers/fused_moe/layer.py | 22 +- .../layers/quantization/awq_marlin.py | 2 + .../compressed_tensors_moe.py | 4 + .../layers/quantization/experts_int8.py | 2 + .../model_executor/layers/quantization/fp8.py | 2 + .../layers/quantization/gptq_marlin.py | 3 + vllm/model_executor/models/grok1.py | 565 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 11 files changed, 634 insertions(+), 17 deletions(-) create mode 100644 vllm/model_executor/models/grok1.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index ae851c35e6266..9959f7233e868 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ * `parasail-ai/GritLM-7B-vllm`. * ✅︎ * ✅︎ +- * `Grok1ModelForCausalLM` + * Grok1 + * `hpcai-tech/grok-1`. + * ✅︎ + * ✅︎ - * `InternLMForCausalLM` * InternLM * `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. diff --git a/tests/models/registry.py b/tests/models/registry.py index d89a41dae3aa5..566a4418feb1d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -130,6 +130,8 @@ def check_available_online( "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), + "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", + trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bc9573b36df76..00260313e72eb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, - global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + activation, use_fp8_w8a8, use_int8_w8a16, + use_int4_w4a16, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def inplace_fused_experts_fake( @@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1093,6 +1096,7 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1106,7 +1110,7 @@ def outplace_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, - False, use_fp8_w8a8, use_int8_w8a16, + False, activation, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor, if inplace: torch.ops.vllm.inplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + hidden_states, w1, w2, topk_weights, topk_ids, activation, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) return hidden_states else: return torch.ops.vllm.outplace_fused_experts( - hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, - use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) + hidden_states, w1, w2, topk_weights, topk_ids, activation, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") invoke_fused_moe_kernel(intermediate_cache2, w2, @@ -1339,6 +1354,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + activation: str = "silu", use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, @@ -1370,6 +1386,8 @@ def fused_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk @@ -1420,6 +1438,7 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, + activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 452f390f49877..42554b61f67ab 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -120,7 +120,8 @@ def apply( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -134,7 +135,8 @@ def apply( expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + activation=activation) def forward_cuda( self, @@ -150,7 +152,8 @@ def forward_cuda( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -170,6 +173,7 @@ def forward_cuda( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, global_num_experts=global_num_experts, expert_map=expert_map) @@ -186,9 +190,11 @@ def forward_cpu( global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, + activation: str = "silu", **kwargs, ): assert custom_routing_function is None + assert activation == "silu", f"{activation} is not supported." return layer.ipex_fusion( x, use_grouped_topk, @@ -213,7 +219,8 @@ def forward_tpu( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None @@ -225,6 +232,7 @@ def forward_tpu( if e_score_correction_bias is not None: raise NotImplementedError( "Expert score correction bias is not supported for TPU.") + assert activation == "silu", f"{activation} is not supported for TPU." return fused_moe_pallas(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -277,6 +285,7 @@ def __init__( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ): super().__init__() @@ -305,6 +314,7 @@ def __init__( self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias + self.activation = activation self.expert_map = None if self.ep_size > 1: @@ -653,7 +663,9 @@ def forward(self, hidden_states: torch.Tensor, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias) + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 7a2fb203dec35..473816fcc3ecd 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -469,7 +469,9 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." if expert_map is not None: raise NotImplementedError( "Expert Parallelism is not supported for " diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f1f316f083396..c9aa0ec285baf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -219,6 +219,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -240,6 +241,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, global_num_experts=global_num_experts, expert_map=expert_map, @@ -550,7 +552,9 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." if expert_map is not None: raise NotImplementedError( "Expert Parallelism is not supported for " diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 0767926ee5c0c..d18ca55afebdb 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -113,6 +113,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -134,6 +135,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_int8_w8a16=True, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5e1bec0bb4be2..76a7d4df8a369 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -675,6 +675,7 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -698,6 +699,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94a1de71bbca0..21db8ccba059c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -590,7 +590,10 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." + # The input must currently be float16 orig_dtype = x.dtype x = x.half() diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 0000000000000..f2e82017f6530 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +# Default Grok1-specific constants, overridden by config values if present +DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 +DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257 +DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169 + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.config = config # Store config reference + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + attn_logits_soft_cap = max( + getattr(config, "attn_logit_softcapping", 30.0), 0.0) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + + # Apply attention output multiplier if specified in config + attn_multiplier = getattr(self.config, "attn_output_multiplier", + None) if self.config else None + if attn_multiplier is not None: + output = output * attn_multiplier + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Check for fp8 quantization + self.use_fp8 = False + if quant_config is not None: + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", + lambda: False)() + if not self.use_fp8 and hasattr(quant_config, "is_fp8"): + self.use_fp8 = quant_config.is_fp8 + + # Requires transformers > 4.32.0 + # Default rope_theta value if not in config + rope_theta = 10000 + self.attn = Grok1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + config=config) # Pass config to Grok1Attention + + # Grok1 uses "num_experts" in its config + num_experts = getattr(config, "num_experts", 8) + num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) + + self.moe_block = Grok1MoE(num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, residual) + + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Post attention normalization + hidden_states = self.post_attn_norm(hidden_states) + + # MoE block with normalization + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + hidden_states = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Grok1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( + config, "embedding_multiplier_scale", + DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = Grok1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.output_multiplier_scale = getattr( + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + self.output_multiplier_scale) + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # Handle Grok1-specific norm.scale naming + if "norm.scale" in name: + name = name.replace("scale", "weight") + + # Skip lm_head when tie_word_embeddings is True + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 05fb3d21953dd..58155905a7b70 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -60,6 +60,7 @@ "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), "GritLM": ("gritlm", "GritLM"), + "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),