80 changes: 30 additions & 50 deletions vllm/compilation/collective_fusion.py
@@ -10,7 +10,6 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@@ -395,30 +394,20 @@ def __call__(self, graph: fx.Graph):
if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None

MiB = 1024 * 1024
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: MiB, # 1MB
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
def flashinfer_max_size(world_size: int, config: VllmConfig) -> int:
"""
Returns the max communication size in bytes for flashinfer
allreduce fusion for the given world size. Falls back to
conservative defaults if the world size is not specified in config.
"""
MiB = 1024 * 1024
max_sizes = {
k: int(v * MiB)
for k, v in config.compilation_config.pass_config.
fi_allreduce_fusion_max_size_mb.items()
}

try:
_FI_MAX_SIZES.update({
int(k): int(float(v) * MiB)
for k, v in
envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
})
except Exception as e:
raise ValueError(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
+ str(e)) from e

# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
return max_sizes.get(world_size, MiB // 2)

def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
@@ -438,14 +427,8 @@ def call_trtllm_fused_allreduce_norm(
scale_out: Optional[torch.Tensor] = None,
scale_factor: Optional[torch.Tensor] = None,
) -> None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_fusion_size = max_token_num * hidden_size * element_size
use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
num_tokens, _ = allreduce_in.shape
use_flashinfer = num_tokens <= max_token_num
if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer"
@@ -561,9 +544,9 @@ def __init__(
self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
fuse_rms_quant: bool = False,
use_fp32_lamport: bool,
max_token_num: int,
fuse_rms_quant: bool,
):
self.rank = rank
self.world_size = world_size
@@ -1089,24 +1072,21 @@ def __init__(self, config: VllmConfig):
"Flashinfer is not installed or comm module not found, "
"skipping allreduce fusion pass")
return
# Check if the world size is supported
if self.tp_size not in _FI_MAX_SIZES:
logger.warning(
"Flashinfer allreduce fusion is not "
"supported for world size %s",
self.tp_size,
)
return
max_num_token = min(
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
(self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
config.compilation_config.pass_config.
fi_allreduce_fusion_max_token_num)

max_size = flashinfer_max_size(self.tp_size, config)
element_size = 4 if use_fp32_lamport else 2
max_token_num = (max_size //
(self.hidden_dim * element_size * self.tp_size))
# take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyway
max_token_num = min(max_token_num,
config.scheduler_config.max_num_batched_tokens)

self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=max_num_token,
max_token_num=max_token_num,
hidden_dim=self.hidden_dim,
group=self.group,
use_fp32_lamport=use_fp32_lamport,
@@ -1118,7 +1098,7 @@ def __init__(self, config: VllmConfig):
rank=rank,
world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
max_token_num=max_num_token,
max_token_num=max_token_num,
# fuse rms norm static fp8 quant fused op
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
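To make the new sizing concrete, here is a small worked example of how the MiB threshold returned by `flashinfer_max_size` turns into the `max_token_num` used for the workspace allocation above. All numeric values are illustrative assumptions, not values from this PR.

```python
# Worked example of the max_token_num derivation (illustrative values only).
MiB = 1024 * 1024

max_size = 64 * MiB            # threshold for world size 2 (default: 64 MiB)
hidden_dim = 8192              # assumed model hidden size
element_size = 2               # fp16/bf16 elements; 4 if use_fp32_lamport
tp_size = 2                    # tensor-parallel world size
max_num_batched_tokens = 8192  # assumed scheduler limit

max_token_num = max_size // (hidden_dim * element_size * tp_size)
# 67108864 // 32768 == 2048 tokens of workspace

# Never allocate more than the scheduler can batch anyway.
max_token_num = min(max_token_num, max_num_batched_tokens)
print(max_token_num)  # 2048
```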
21 changes: 19 additions & 2 deletions vllm/config/compilation.py
@@ -87,8 +87,14 @@ class PassConfig:
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_token_num: int = 16384
"""Max number of tokens to used in flashinfer allreduce fusion."""
fi_allreduce_fusion_max_size_mb: dict[int,
float] = field(default_factory=dict)
"""The thresholds of the communicated tensor sizes under which
vLLM should use flashinfer fused allreduce. Specified as a
dictionary mapping each world size to the threshold in MiB
{ <world size>: <max size in MiB> }
Unspecified world sizes will fall back to
{ 2: 64, 4: 1, <everything else>: 0.5 }"""

# TODO(luka) better pass enabling system.

@@ -460,6 +466,17 @@ def __post_init__(self) -> None:
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL

fi_allreduce_fusion_max_size_mb = {
2: 64,
4: 1,
6: 0.5,
8: 0.5,
}
fi_allreduce_fusion_max_size_mb.update(
self.pass_config.fi_allreduce_fusion_max_size_mb)
self.pass_config.fi_allreduce_fusion_max_size_mb = \
fi_allreduce_fusion_max_size_mb

def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
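For clarity, the merge in `__post_init__` keeps the conservative defaults and only overrides the world sizes the user actually specifies. A minimal sketch of that dictionary update, using a hypothetical user override:

```python
# Defaults baked into __post_init__ (MiB per world size).
defaults = {2: 64, 4: 1, 6: 0.5, 8: 0.5}

# Hypothetical user-supplied pass_config.fi_allreduce_fusion_max_size_mb.
user_override = {4: 2, 16: 0.25}

defaults.update(user_override)
print(defaults)
# -> {2: 64, 4: 2, 6: 0.5, 8: 0.5, 16: 0.25}
```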
11 changes: 0 additions & 11 deletions vllm/envs.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
import json
import os
import sys
import tempfile
@@ -1061,16 +1060,6 @@ def get_vllm_port() -> Optional[int]:
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),

# Specifies the thresholds of the communicated tensor sizes under which
# vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format:
# { <world size>: <max size in mb> }
# Unspecified world sizes will fall back to
# { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
lambda: json.loads(os.getenv(
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),

# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available
# strategies.
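With the environment variable removed, the thresholds now live in the compilation pass config. A minimal sketch of how they might be specified; the import path and keyword arguments are assumptions based on the fields added in `vllm/config/compilation.py` above, not a verified public API:

```python
# Sketch only: assumes PassConfig is importable from this module and accepts
# these keyword arguments, mirroring the fields shown in the diff above.
from vllm.config.compilation import PassConfig

pass_config = PassConfig(
    enable_fi_allreduce_fusion=True,
    # world size -> max communication size in MiB; unspecified world sizes
    # fall back to {2: 64, 4: 1, <everything else>: 0.5}.
    fi_allreduce_fusion_max_size_mb={2: 32, 8: 1},
)
```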