
Commit e21a6fb

tlrmchlsmth authored and lulmer committed
[Distributed] Add enable_expert_parallel arg (vllm-project#14305)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent ce28a50 commit e21a6fb

File tree

examples/offline_inference/data_parallel.py
vllm/config.py
vllm/engine/arg_utils.py
vllm/envs.py
vllm/model_executor/layers/fused_moe/layer.py

5 files changed: +27 -21 lines changed

examples/offline_inference/data_parallel.py

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # usage:
-# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
-#     python examples/offline_inference/data_parallel.py
+# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     # Create an LLM.
     llm = LLM(model="ibm-research/PowerMoE-3b",
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True)
+              enforce_eager=True,
+              enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
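For reference, a minimal standalone sketch (not part of this commit) of turning the new argument on through the Python API; the prompt and parallel size are illustrative only, and any MoE model would do in place of the one used by the example:

from vllm import LLM, SamplingParams

# Route MoE layers through expert parallelism instead of sharding every
# expert across the tensor-parallel group (the flag added in this commit).
llm = LLM(model="ibm-research/PowerMoE-3b",
          tensor_parallel_size=2,
          enforce_eager=True,
          enable_expert_parallel=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
for output in outputs:
    print(output.outputs[0].text)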

vllm/config.py

Lines changed: 2 additions & 1 deletion

@@ -754,7 +754,7 @@ def verify_with_parallel_config(
             " must be divisible by tensor parallel size "
             f"({tensor_parallel_size}).")

-        if envs.VLLM_TEST_ENABLE_EP:
+        if parallel_config.enable_expert_parallel:
            self._verify_with_expert_parallelism()

         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -1334,6 +1334,7 @@ class ParallelConfig:
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
+    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
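The flag lands on ParallelConfig as a plain dataclass-style field, so it can also be set when building a config programmatically. A minimal sketch, purely for illustration, assuming ParallelConfig accepts keyword arguments for its fields as the annotations in this diff suggest:

from vllm.config import ParallelConfig

# Sketch: the new field defaults to False; when set, model verification
# additionally runs _verify_with_expert_parallelism() (gated in the hunk above).
parallel_config = ParallelConfig(
    tensor_parallel_size=4,           # illustrative size
    enable_expert_parallel=True,      # new in this commit
)
print(parallel_config.enable_expert_parallel)  # True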

vllm/engine/arg_utils.py

Lines changed: 7 additions & 0 deletions

@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
@@ -440,6 +441,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help='Number of tensor parallel replicas.')
+        parser.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            help='Use expert parallelism instead of tensor parallelism '
+            'for MoE layers.')
         parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
@@ -1207,6 +1213,7 @@ def create_engine_config(self,
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
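With the flag wired into EngineArgs, it can also be given on the command line as --enable-expert-parallel. A minimal sketch of the parsing path, assuming FlexibleArgumentParser and EngineArgs.from_cli_args behave as elsewhere in vLLM; the model name and sizes are illustrative:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Parse the new flag the way a CLI entrypoint would, then check that it
# ends up on EngineArgs (and from there on ParallelConfig, as shown above).
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "ibm-research/PowerMoE-3b",
    "--tensor-parallel-size", "2",
    "--enable-expert-parallel",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.enable_expert_parallel)  # True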

vllm/envs.py

Lines changed: 0 additions & 7 deletions

@@ -86,7 +86,6 @@
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
-    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -579,12 +578,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
     ),

-    # If set, vLLM will use the experimental expert parallel implementation on
-    # the FusedMoE layer, using tensor parallelism size as expert parallelism
-    # size.
-    "VLLM_TEST_ENABLE_EP":
-    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
-
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 10 deletions

@@ -7,7 +7,6 @@
 import torch
 from torch.nn.parameter import UninitializedParameter

-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
@@ -361,7 +352,21 @@ def __init__(
                         if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts

-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
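The only expert-parallel bookkeeping visible in this hunk is the rank mapping; the matching EP-size and TP-size adjustments sit below the shown context. A small illustrative sketch (not from the commit) of how the visible formula flattens the tensor-parallel by data-parallel grid into expert-parallel ranks:

# ep_rank = tp_rank + tp_size * dp_rank, as in the added line above.
def ep_rank(tp_rank: int, tp_size: int, dp_rank: int) -> int:
    return tp_rank + tp_size * dp_rank

# With tp_size=2 and dp_size=2, the four workers map to EP ranks 0..3,
# so the EP world size would presumably be tp_size * dp_size.
for dp in range(2):
    for tp in range(2):
        print(f"dp_rank={dp} tp_rank={tp} -> ep_rank={ep_rank(tp, 2, dp)}")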

0 commit comments
