forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ Misc ] Refactor MoE to isolate Fp8 From Mixtral (vllm-project#5970)
Co-authored-by: Robert Shaw <rshaw@neuralmagic> Co-authored-by: Michael Goin <michael@neuralmagic.com>
- Loading branch information
1 parent
8a84d22
commit 6cb32da
Showing
10 changed files
with
537 additions
and
306 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8 | ||
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic" | ||
tasks: | ||
- name: "gsm8k" | ||
metrics: | ||
- name: "exact_match,strict-match" | ||
value: 0.86 | ||
- name: "exact_match,flexible-extract" | ||
value: 0.86 | ||
limit: 250 | ||
num_fewshot: 5 |
11 changes: 11 additions & 0 deletions
11
.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4 | ||
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" | ||
tasks: | ||
- name: "gsm8k" | ||
metrics: | ||
- name: "exact_match,strict-match" | ||
value: 0.624 | ||
- name: "exact_match,flexible-extract" | ||
value: 0.624 | ||
limit: 250 | ||
num_fewshot: 5 |
11 changes: 11 additions & 0 deletions
11
.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 | ||
model_name: "Qwen/Qwen2-57B-A14B-Instruct" | ||
tasks: | ||
- name: "gsm8k" | ||
metrics: | ||
- name: "exact_match,strict-match" | ||
value: 0.792 | ||
- name: "exact_match,flexible-extract" | ||
value: 0.824 | ||
limit: 250 | ||
num_fewshot: 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
Meta-Llama-3-70B-Instruct.yaml | ||
Mixtral-8x7B-Instruct-v0.1.yaml | ||
Qwen2-57B-A14-Instruct.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
from vllm.model_executor.layers.fused_moe.fused_moe import ( | ||
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) | ||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, | ||
FusedMoEMethodBase) | ||
|
||
__all__ = [ | ||
"fused_moe", | ||
"fused_topk", | ||
"fused_experts", | ||
"get_config_file_name", | ||
"grouped_topk", | ||
"FusedMoE", | ||
"FusedMoEMethodBase", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
from abc import abstractmethod | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from vllm.distributed import (get_tensor_model_parallel_rank, | ||
get_tensor_model_parallel_world_size, | ||
tensor_model_parallel_all_reduce) | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe | ||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig, QuantizeMethodBase) | ||
from vllm.model_executor.utils import set_weight_attrs | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class FusedMoEMethodBase(QuantizeMethodBase): | ||
|
||
@abstractmethod | ||
def create_weights(self, layer: torch.nn.Module, num_experts: int, | ||
hidden_size: int, intermediate_size: int, | ||
params_dtype: torch.dtype, **extra_weight_attrs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def apply(self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
router_logits: torch.Tensor, | ||
top_k: int, | ||
renormalize: bool = True) -> torch.Tensor: | ||
raise NotImplementedError | ||
|
||
|
||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase): | ||
"""MoE method without quantization.""" | ||
|
||
def create_weights(self, layer: torch.nn.Module, num_experts: int, | ||
hidden_size: int, intermediate_size: int, | ||
params_dtype: torch.dtype, **extra_weight_attrs): | ||
|
||
# Fused gate_up_proj (column parallel) | ||
w13_weight = torch.nn.Parameter(torch.empty(num_experts, | ||
2 * intermediate_size, | ||
hidden_size, | ||
dtype=params_dtype), | ||
requires_grad=False) | ||
layer.register_parameter("w13_weight", w13_weight) | ||
set_weight_attrs(w13_weight, extra_weight_attrs) | ||
|
||
# down_proj (row parallel) | ||
w2_weight = torch.nn.Parameter(torch.empty(num_experts, | ||
hidden_size, | ||
intermediate_size, | ||
dtype=params_dtype), | ||
requires_grad=False) | ||
layer.register_parameter("w2_weight", w2_weight) | ||
set_weight_attrs(w2_weight, extra_weight_attrs) | ||
|
||
def apply(self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
router_logits: torch.Tensor, | ||
top_k: int, | ||
renormalize: bool = True) -> torch.Tensor: | ||
|
||
return fused_moe(x, | ||
layer.w13_weight, | ||
layer.w2_weight, | ||
router_logits, | ||
top_k, | ||
renormalize=renormalize, | ||
inplace=True) | ||
|
||
|
||
class FusedMoE(torch.nn.Module): | ||
"""FusedMoE layer for MoE models. | ||
This layer contains both MergedColumnParallel weights (gate_up_proj / | ||
w13) and RowParallelLinear weights (down_proj/ w2). | ||
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We | ||
copy that naming convention here and handle any remapping in the | ||
load_weights function in each model implementation. | ||
Args: | ||
num_experts: Number of experts in the model | ||
top_k: Number of experts selected for each token | ||
hidden_size: Input hidden state size of the transformer | ||
intermediate_size: Intermediate size of the experts | ||
params_dtype: Data type for the parameters. | ||
reduce_results: Whether to all all_reduce on the output of the layer | ||
renomalize: Whether to renormalize the logits in the fused_moe kernel | ||
quant_config: Quantization configure. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_experts: int, | ||
top_k: int, | ||
hidden_size: int, | ||
intermediate_size: int, | ||
params_dtype: Optional[torch.dtype] = None, | ||
reduce_results: bool = False, | ||
renormalize: bool = True, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
tp_size: Optional[int] = None, | ||
): | ||
super().__init__() | ||
|
||
if params_dtype is None: | ||
params_dtype = torch.get_default_dtype() | ||
|
||
self.tp_size = (tp_size if tp_size is not None else | ||
get_tensor_model_parallel_world_size()) | ||
self.top_k = top_k | ||
self.num_experts = num_experts | ||
self.intermediate_size_per_partition = intermediate_size // self.tp_size | ||
self.reduce_results = reduce_results | ||
self.renormalize = renormalize | ||
|
||
if quant_config is None: | ||
self.quant_method: Optional[QuantizeMethodBase] = ( | ||
UnquantizedFusedMoEMethod()) | ||
else: | ||
self.quant_method = quant_config.get_quant_method(self) | ||
assert self.quant_method is not None | ||
|
||
self.quant_method.create_weights( | ||
layer=self, | ||
num_experts=num_experts, | ||
hidden_size=hidden_size, | ||
intermediate_size=self.intermediate_size_per_partition, | ||
params_dtype=params_dtype, | ||
weight_loader=self.weight_loader) | ||
|
||
def weight_loader(self, param: torch.nn.Parameter, | ||
loaded_weight: torch.Tensor, weight_name: str, | ||
shard_id: int, expert_id: int): | ||
param_data = param.data | ||
|
||
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. | ||
# Follow up PR to enable fp8 for other MoE models. | ||
if "input_scale" in weight_name or "w2.weight_scale" in weight_name: | ||
if param_data[expert_id] != 1 and (param_data[expert_id] - | ||
loaded_weight).abs() > 1e-5: | ||
raise ValueError( | ||
"input_scales of w1 and w3 of a layer " | ||
f"must be equal. But got {param_data[expert_id]} " | ||
f"vs. {loaded_weight}") | ||
param_data[expert_id] = loaded_weight | ||
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. | ||
# Follow up PR to enable fp8 for other MoE models. | ||
elif "weight_scale" in weight_name: | ||
# We have to keep the weight scales of w1 and w3 because | ||
# we need to re-quantize w1/w3 weights after weight loading. | ||
assert "w1" in weight_name or "w3" in weight_name | ||
shard_id = 0 if "w1" in weight_name else 1 | ||
param_data[expert_id][shard_id] = loaded_weight | ||
else: | ||
tp_rank = get_tensor_model_parallel_rank() | ||
shard_size = self.intermediate_size_per_partition | ||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) | ||
|
||
# w1, gate_proj case: Load into first shard of w13. | ||
if shard_id == 0: | ||
param_data[expert_id, | ||
0:shard_size, :] = loaded_weight[shard, :] | ||
# w3, up_proj case: Load into second shard of w13. | ||
elif shard_id == 2: | ||
param_data[expert_id, shard_size:2 * | ||
shard_size, :] = loaded_weight[shard, :] | ||
# w2, down_proj case: Load into only shard of w2. | ||
elif shard_id == 1: | ||
param_data[expert_id, :, :] = loaded_weight[:, shard] | ||
else: | ||
raise ValueError( | ||
f"Shard id must be in [0,1,2] but got {shard_id}") | ||
|
||
def forward(self, hidden_states: torch.Tensor, | ||
router_logits: torch.Tensor): | ||
assert self.quant_method is not None | ||
|
||
# Matrix multiply. | ||
final_hidden_states = self.quant_method.apply( | ||
self, | ||
x=hidden_states, | ||
router_logits=router_logits, | ||
top_k=self.top_k, | ||
renormalize=self.renormalize) | ||
|
||
if self.reduce_results and self.tp_size > 1: | ||
final_hidden_states = tensor_model_parallel_all_reduce( | ||
final_hidden_states) | ||
|
||
return final_hidden_states |
Oops, something went wrong.