[Distributed] Add enable_expert_parallel arg #14305

Merged: 1 commit, Mar 6, 2025
6 changes: 3 additions & 3 deletions examples/offline_inference/data_parallel.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # usage:
-# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
-#     python examples/offline_inference/data_parallel.py
+# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     # Create an LLM.
     llm = LLM(model="ibm-research/PowerMoE-3b",
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True)
+              enforce_eager=True,
+              enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
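The example relies on a launcher (outside this diff) to spawn one process per data-parallel rank, as the usage comment above notes; each rank then builds its own `LLM(..., enable_expert_parallel=True)`. Below is a minimal sketch of what such a launcher can look like; it assumes the `main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank)` entry point shown in the hunk header, and the sizes are purely illustrative.

```python
# Hypothetical launcher sketch: spawn one process per DP rank. Not the PR's
# code; main() refers to the example's entry point above and is not defined
# in this snippet.
from multiprocessing import Process

if __name__ == "__main__":
    dp_size = 2                 # assumed: two data-parallel ranks
    gpus_per_dp_rank = 1        # assumed: one GPU (TP=1) per DP rank
    dp_master_ip = "127.0.0.1"
    dp_master_port = 29500

    procs = []
    for dp_rank in range(dp_size):
        proc = Process(target=main,
                       args=(dp_size, dp_rank, dp_master_ip, dp_master_port,
                             gpus_per_dp_rank))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
```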
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -754,7 +754,7 @@ def verify_with_parallel_config(
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        if envs.VLLM_TEST_ENABLE_EP:
+        if parallel_config.enable_expert_parallel:
             self._verify_with_expert_parallelism()
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -1334,6 +1334,7 @@ class ParallelConfig:
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
+    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
 
     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
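The body of `_verify_with_expert_parallelism` is not part of this diff. Judging from the error message in the surrounding context, the gated check is a divisibility constraint on the expert count; the snippet below is only a rough standalone sketch of that kind of check, with an assumed function name, message, and numbers rather than the PR's actual code.

```python
def verify_expert_parallelism_sketch(num_experts: int,
                                     tensor_parallel_size: int) -> None:
    # Hypothetical stand-in for the check gated by
    # parallel_config.enable_expert_parallel: when experts are sharded across
    # ranks, the expert count has to split evenly.
    if num_experts % tensor_parallel_size != 0:
        raise ValueError(
            f"Number of experts ({num_experts})"
            " must be divisible by tensor parallel size "
            f"({tensor_parallel_size}).")


verify_expert_parallelism_sketch(num_experts=8, tensor_parallel_size=4)  # ok
```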
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
@@ -439,6 +440,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            help='Use expert parallelism instead of tensor parallelism '
+            'for MoE layers.')

Review discussion (on the --enable-expert-parallel help text):
Member: Add a comment to say that EP is the multiplication of DP and TP?
Collaborator (author): Actually I think we should leave this as-is and add this to the --data-parallel argument added in #13923. We parallelize across all DP*TP ranks for the MoE layers regardless of whether we are using EP.
Member: sounds good.

         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
@@ -1199,6 +1205,7 @@ def create_engine_config(self,
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
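As a quick sanity check of the new flag's plumbing, the CLI switch should land on `EngineArgs.enable_expert_parallel`, which `create_engine_config` then copies into `ParallelConfig`. The sketch below is not part of the PR; it assumes the existing `EngineArgs.from_cli_args` helper and `FlexibleArgumentParser` from `vllm.utils`.

```python
# Sketch: parse the new --enable-expert-parallel switch into EngineArgs.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--enable-expert-parallel"])
engine_args = EngineArgs.from_cli_args(args)

# store_true flag: defaults to False, True only when passed on the CLI.
assert engine_args.enable_expert_parallel is True
```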
7 changes: 0 additions & 7 deletions vllm/envs.py
@@ -86,7 +86,6 @@
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
-    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -578,12 +577,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
     ),
 
-    # If set, vLLM will use the experimental expert parallel implementation on
-    # the FusedMoE layer, using tensor parallelism size as expert parallelism
-    # size.
-    "VLLM_TEST_ENABLE_EP":
-    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
-
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
25 changes: 15 additions & 10 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -7,7 +7,6 @@
 import torch
 from torch.nn.parameter import UninitializedParameter
 
-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
@@ -361,7 +352,21 @@
                         if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts
 
-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
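To make the rank bookkeeping above concrete, and the review point that EP spans the product of DP and TP: every (dp_rank, tp_rank) pair maps to a distinct expert-parallel rank via the formula in the diff, so the effective EP world size is tp_size * dp_size. The following is a small worked sketch, not code from the PR, with illustrative sizes.

```python
# Worked example of the EP rank layout: ep_rank = tp_rank + tp_size * dp_rank.
tp_size, dp_size = 2, 2  # illustrative sizes

layout = {}
for dp_rank in range(dp_size):
    for tp_rank in range(tp_size):
        ep_rank = tp_rank + tp_size * dp_rank
        layout[(dp_rank, tp_rank)] = ep_rank

print(layout)
# {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
# -> 4 distinct EP ranks = tp_size * dp_size, i.e. the MoE experts are sharded
#    across all DP*TP ranks, matching the discussion on the
#    --enable-expert-parallel help text above.
```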