This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 4d31612

robertgshaw2-redhat (Robert Shaw) authored and mgoin committed
[ Misc ] Refactor MoE to isolate Fp8 From Mixtral (vllm-project#5970)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent d377439 commit 4d31612

File tree: 10 files changed (+537, -306 lines)
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.86
  - name: "exact_match,flexible-extract"
    value: 0.86
limit: 250
num_fewshot: 5
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.624
  - name: "exact_match,flexible-extract"
    value: 0.624
limit: 250
num_fewshot: 5
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.792
  - name: "exact_match,flexible-extract"
    value: 0.824
limit: 250
num_fewshot: 5
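These configs record the expected GSM8K scores for each model; roughly, the CI harness runs lm-eval through a vLLM baseline server and compares the measured metrics against the recorded values. A minimal sketch of that idea, assuming PyYAML is available; the check_accuracy helper, its arguments, and the RTOL tolerance are hypothetical placeholders, not taken from this commit:

import yaml

RTOL = 0.05  # assumed tolerance; not specified in the commit


def check_accuracy(config_path: str, measured: dict) -> None:
    """Compare measured metric scores against the expected values in a config.

    `measured` maps metric name -> score, e.g.
    {"exact_match,strict-match": 0.85, "exact_match,flexible-extract": 0.85}.
    """
    with open(config_path) as f:
        config = yaml.safe_load(f)

    for task in config["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[metric["name"]]
            assert abs(got - expected) <= RTOL, (
                f"{config['model_name']} {task['name']} {metric['name']}: "
                f"expected ~{expected}, measured {got}")

# Example: check_accuracy("path/to/config.yaml",
#                         {"exact_match,strict-match": 0.63,
#                          "exact_match,flexible-extract": 0.63})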
Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
+Qwen2-57B-A14-Instruct.yaml

tests/kernels/test_moe.py

Lines changed: 2 additions & 2 deletions
@@ -83,8 +83,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
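The updated test reaches the expert weights through vllm_moe.experts, which suggests the Mixtral MoE block now keeps only its router and delegates the expert parameters to a FusedMoE submodule named experts. A hedged sketch of that composition, assuming vLLM with this commit is installed; the class name and structure below are illustrative, since the Mixtral model file itself is not part of this excerpt:

import torch
from torch import nn

from vllm.model_executor.layers.fused_moe import FusedMoE


class MixtralStyleMoEBlock(nn.Module):
    """Illustrative MoE block: router here, expert weights inside FusedMoE."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 num_experts: int, top_k: int):
        super().__init__()
        # Router producing one logit per expert for every token.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Fused expert weights (w13 and w2) live inside the FusedMoE layer.
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                renormalize=True,
                                tp_size=1)  # tp_size=1: no distributed setup

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        return self.experts(hidden_states, router_logits)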
Lines changed: 4 additions & 0 deletions
@@ -1,10 +1,14 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)

 __all__ = [
     "fused_moe",
     "fused_topk",
     "fused_experts",
     "get_config_file_name",
     "grouped_topk",
+    "FusedMoE",
+    "FusedMoEMethodBase",
 ]
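With these exports, model code can import the new layer from the fused_moe package root instead of the layer module. A minimal check, assuming vLLM with this commit is installed:

from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase

print(FusedMoE)            # the torch.nn.Module defined in layer.py
print(FusedMoEMethodBase)  # abstract base for quantized/unquantized MoE methods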
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
from abc import abstractmethod
from typing import Optional

import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:

        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True)


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w3, and w2 for the gate, up, and down projections,
    respectively. We copy that naming convention here and handle any remapping
    in the load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe kernel.
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: int, expert_id: int):
        param_data = param.data

        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
        # Follow up PR to enable fp8 for other MoE models.
        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
        # Follow up PR to enable fp8 for other MoE models.
        elif "weight_scale" in weight_name:
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            assert "w1" in weight_name or "w3" in weight_name
            shard_id = 0 if "w1" in weight_name else 1
            param_data[expert_id][shard_id] = loaded_weight
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.intermediate_size_per_partition
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == 0:
                param_data[expert_id,
                           0:shard_size, :] = loaded_weight[shard, :]
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == 2:
                param_data[expert_id, shard_size:2 *
                           shard_size, :] = loaded_weight[shard, :]
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == 1:
                param_data[expert_id, :, :] = loaded_weight[:, shard]
            else:
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states
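A hedged usage sketch of the new layer, assuming vLLM with this commit, a CUDA device, and tp_size=1 so no distributed process group is required; the sizes are illustrative and the freshly created weights are uninitialized, so only the output shape is meaningful:

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE

# Small illustrative sizes; real Mixtral shapes are much larger.
num_experts, top_k, hidden_size, intermediate_size = 8, 2, 1024, 2048

moe = FusedMoE(num_experts=num_experts,
               top_k=top_k,
               hidden_size=hidden_size,
               intermediate_size=intermediate_size,
               params_dtype=torch.float16,
               reduce_results=False,   # no all_reduce needed with tp_size=1
               renormalize=True,
               quant_config=None,      # selects UnquantizedFusedMoEMethod
               tp_size=1).to("cuda")

# Router logits would normally come from a gating nn.Linear over hidden_states.
num_tokens = 16
hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
router_logits = torch.randn(num_tokens, num_experts,
                            dtype=torch.float16, device="cuda")

out = moe(hidden_states, router_logits)
print(out.shape)  # torch.Size([16, 1024]); values are garbage (empty weights)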
