Skip to content

Commit

Permalink
[Model] Support Grok1 (vllm-project#13795)
Browse files Browse the repository at this point in the history
Signed-off-by: mgoin <mgoin64@gmail.com>
  • Loading branch information
mgoin authored Feb 26, 2025
1 parent 34e3494 commit 07c4353
Show file tree
Hide file tree
Showing 11 changed files with 634 additions and 17 deletions.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `parasail-ai/GritLM-7B-vllm`.
* ✅︎
* ✅︎
- * `Grok1ModelForCausalLM`
* Grok1
* `hpcai-tech/grok-1`.
* ✅︎
* ✅︎
- * `InternLMForCausalLM`
* InternLM
* `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def check_available_online(
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
Expand Down
43 changes: 31 additions & 12 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand All @@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)


def inplace_fused_experts_fake(
Expand All @@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand Down Expand Up @@ -1093,6 +1096,7 @@ def outplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand All @@ -1106,7 +1110,7 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16,
False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
Expand All @@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand Down Expand Up @@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand All @@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,

if inplace:
torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)


def fused_experts_impl(hidden_states: torch.Tensor,
Expand All @@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
Expand Down Expand Up @@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)

torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")

invoke_fused_moe_kernel(intermediate_cache2,
w2,
Expand Down Expand Up @@ -1339,6 +1354,7 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
Expand Down Expand Up @@ -1370,6 +1386,8 @@ def fused_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
Expand Down Expand Up @@ -1420,6 +1438,7 @@ def fused_moe(
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
Expand Down
22 changes: 17 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def apply(
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
Expand All @@ -134,7 +135,8 @@ def apply(
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
activation=activation)

def forward_cuda(
self,
Expand All @@ -150,7 +152,8 @@ def forward_cuda(
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
Expand All @@ -170,6 +173,7 @@ def forward_cuda(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map)

Expand All @@ -186,9 +190,11 @@ def forward_cpu(
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
**kwargs,
):
assert custom_routing_function is None
assert activation == "silu", f"{activation} is not supported."
return layer.ipex_fusion(
x,
use_grouped_topk,
Expand All @@ -213,7 +219,8 @@ def forward_tpu(
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
Expand All @@ -225,6 +232,7 @@ def forward_tpu(
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU.")
assert activation == "silu", f"{activation} is not supported for TPU."
return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
Expand Down Expand Up @@ -277,6 +285,7 @@ def __init__(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
super().__init__()

Expand Down Expand Up @@ -305,6 +314,7 @@ def __init__(
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.activation = activation
self.expert_map = None

if self.ep_size > 1:
Expand Down Expand Up @@ -653,7 +663,9 @@ def forward(self, hidden_states: torch.Tensor,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias)
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

Expand All @@ -240,6 +241,7 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
Expand Down Expand Up @@ -550,7 +552,9 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/experts_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

Expand All @@ -134,6 +135,7 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int8_w8a16=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,7 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

Expand All @@ -698,6 +699,7 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,10 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."

# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
Expand Down
Loading

0 comments on commit 07c4353

Please sign in to comment.