
Commit 5fac73f

Add weight padding for moe (#119)
* add weight padding for moe
* enable padding by default
* fix linter
* fix linter
* fix linter
* using envs.py
* fix linter
1 parent 42b1b9a commit 5fac73f

3 files changed: +20 -3 lines

vllm/envs.py (5 additions, 0 deletions)

@@ -40,6 +40,7 @@
     VERBOSE: bool = False
     VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1
     VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
+    VLLM_MOE_PADDING: bool = True

 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.
@@ -229,6 +230,10 @@
     # Poll for new requests every this many steps
     "VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS":
     lambda: int(os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1")),
+
+    # Pad the weight for moe kernel or not
+    "VLLM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))),
 }

 # end-env-vars-definition

vllm/model_executor/layers/fused_moe/fused_moe.py (6 additions, 3 deletions)

@@ -10,9 +10,11 @@

 import vllm._moe_C as moe_kernels
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.logger import init_logger

 logger = init_logger(__name__)
+padding_size = 128 if envs.VLLM_MOE_PADDING else 0


 @triton.jit
@@ -262,7 +264,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -365,7 +367,8 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[
+        1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -381,7 +384,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
+        configs = get_moe_configs(E, w2.shape[2] - padding_size,
                                   "float8" if use_fp8 else None)

         if configs:
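To illustrate why the kernel arguments and the shape check now subtract padding_size, here is a short standalone sketch with toy shapes. The sizes and tensor names below are assumptions for illustration, not values taken from vLLM: the expert weight gains 128 zero columns on its last dimension, but the matmul must still be told the original hidden size K, so the padded tail is never read.

import torch
import torch.nn.functional as F

padding_size = 128            # matches the module-level constant added above
E, N, K = 2, 64, 32           # toy: experts, output features, hidden size

w1 = torch.randn(E, N, K)                  # unpadded expert weight
w1_padded = F.pad(w1, (0, padding_size))   # zero-pad last dim -> (E, N, K + 128)

# The kernel now receives the unpadded reduction dimension:
assert w1_padded.shape[2] - padding_size == K
# The original values are untouched and the padded tail is all zeros:
assert torch.equal(w1_padded[..., :K], w1)
assert torch.all(w1_padded[..., K:] == 0)

Since the kernel is still told the original K, the padding changes only the memory layout (the stride between rows of each expert matrix), not the computed output.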

vllm/model_executor/models/mixtral.py (9 additions, 0 deletions)

@@ -24,10 +24,12 @@
 from typing import Iterable, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -181,6 +183,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
         if not self.use_fp8:
+            if envs.VLLM_MOE_PADDING:
+                self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
+                                                     (0, 128), "constant", 0),
+                                               requires_grad=False)
+                self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
+                                                    (0, 128), "constant", 0),
+                                              requires_grad=False)
             return

         # If checkpoint is fp16, quantize here.
