Skip to content

[Misc] Turn MOE_DP_CHUNK_SIZE into an env var #19506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
Expand Down Expand Up @@ -773,6 +774,14 @@ def get_vllm_port() -> Optional[int]:
"VLLM_DP_MASTER_PORT":
lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),

# In the context of executing MoE models with Data-Parallel, Expert-Parallel
# and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
# dictates the quantum of tokens that can be dispatched from a DP
# rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
# units.
"VLLM_MOE_DP_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),

# Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
Expand Down
17 changes: 9 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)

# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE = 256


@dataclass
class FusedMoEParallelConfig:
Expand Down Expand Up @@ -218,7 +214,12 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128

max_num_tokens: int = MOE_DP_CHUNK_SIZE
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

def __post_init__(self):
if self.dp_size > 1:
logger.debug("Using MOEConfig::max_num_tokens=%d",
self.max_num_tokens)

@property
def tp_size(self):
Expand Down Expand Up @@ -913,7 +914,7 @@ def __init__(
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
self.quant_config = quant_config
Expand Down Expand Up @@ -952,12 +953,12 @@ def __init__(
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
device=torch.cuda.current_device())

self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())

Expand Down