64 changes: 57 additions & 7 deletions python/sglang/srt/distributed/parallel_state.py
@@ -268,6 +268,8 @@ def __init__(
pynccl_use_current_stream: bool = False,
torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
active_ranks: Optional[torch.Tensor] = None,
active_ranks_cpu: Optional[torch.Tensor] = None,
):
# Set group info
group_name = group_name or "anonymous"
@@ -282,14 +284,33 @@ def __init__(
self.local_size = get_int_env_var("LOCAL_SIZE", 0)

for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a cpu_group allows direct coordination between processes through
# the CPU. The backend is chosen based on `torch_distributed_backend`.
if "mooncake" in torch_distributed_backend:
cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
from mooncake.ep import MooncakeBackendOptions

device_group = torch.distributed.new_group(
ranks,
backend="mooncake",
pg_options=(
MooncakeBackendOptions(active_ranks)
if active_ranks is not None
else None
),
)
cpu_group = torch.distributed.new_group(
ranks,
backend="mooncake-cpu",
pg_options=(
MooncakeBackendOptions(active_ranks_cpu)
if active_ranks_cpu is not None
else None
),
)
else:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination
# between processes through the CPU.
cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
@@ -1361,6 +1382,8 @@ def init_model_parallel_group(
pynccl_use_current_stream: bool = True,
use_torch_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None,
active_ranks: Optional[torch.Tensor] = None,
active_ranks_cpu: Optional[torch.Tensor] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -1372,7 +1395,7 @@
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not (_is_npu or _is_xpu),
use_pynccl=not (_is_npu or _is_xpu or backend == "mooncake"),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,
@@ -1383,10 +1406,23 @@
group_name=group_name,
pynccl_use_current_stream=pynccl_use_current_stream,
torch_compile=torch_compile,
active_ranks=active_ranks,
active_ranks_cpu=active_ranks_cpu,
)


_TP: Optional[GroupCoordinator] = None
_TP_ACTIVE_RANKS: Optional[torch.Tensor] = None
_TP_ACTIVE_RANKS_CPU: Optional[torch.Tensor] = None
Collaborator: I'm thinking, would it be more reasonable to rename the elastic ep module to `elastic` and group together concepts like _TP_ACTIVE_RANKS and _TP_ACTIVE_RANKS_CPU as much as possible?

Collaborator: Make it less intrusive to the original logic.

Contributor: Sounds reasonable.


def get_tp_active_ranks():
return _TP_ACTIVE_RANKS


def get_tp_active_ranks_cpu():
return _TP_ACTIVE_RANKS_CPU
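
Following up on the review thread above, one possible shape for that grouping is sketched below. This is a hypothetical illustration, not part of this PR: the `ActiveRanksState` name and layout are assumptions, and the actual elastic-EP code may organize this differently.

```python
# Hypothetical sketch only: keep the per-group active-rank tensors behind one
# small container instead of two bare module-level globals. Names are assumptions.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ActiveRanksState:
    """Device- and CPU-side tensors marking which ranks in a group are alive."""

    device: Optional[torch.Tensor] = None  # analogous to _TP_ACTIVE_RANKS
    cpu: Optional[torch.Tensor] = None  # analogous to _TP_ACTIVE_RANKS_CPU

    @classmethod
    def all_active(cls, world_size: int) -> "ActiveRanksState":
        # Every rank starts active (1); a detected fault flips its slot to 0.
        return cls(
            device=torch.ones((world_size,), dtype=torch.int32, device="cuda"),
            cpu=torch.ones((world_size,), dtype=torch.int32, device="cpu"),
        )


# One handle per parallel group instead of two separate globals.
_TP_ACTIVE: Optional[ActiveRanksState] = None
```

Under that shape, `get_tp_active_ranks()` and `get_tp_active_ranks_cpu()` would become thin accessors on the container, which also fits the second comment about keeping parallel_state.py close to its original logic.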


# duplicate GroupCoordinator for prefill in PD-Multiplexing
_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
@@ -1600,6 +1636,18 @@ def initialize_model_parallel(
)
group_ranks.append(ranks)

global _TP_ACTIVE_RANKS
global _TP_ACTIVE_RANKS_CPU
if backend == "mooncake":
_TP_ACTIVE_RANKS = torch.ones(
(tensor_model_parallel_size,), dtype=torch.int32, device="cuda"
)
_TP_ACTIVE_RANKS_CPU = torch.ones(
(tensor_model_parallel_size,), dtype=torch.int32, device="cpu"
)
else:
_TP_ACTIVE_RANKS = None
_TP_ACTIVE_RANKS_CPU = None
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1611,6 +1659,8 @@
group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
torch_compile=torch_compile,
active_ranks=_TP_ACTIVE_RANKS,
active_ranks_cpu=_TP_ACTIVE_RANKS_CPU,
)

if duplicate_tp_group:
4 changes: 1 addition & 3 deletions python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
@@ -235,14 +235,13 @@ def combine_a(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional[CombineOverlapArgs] = None,
):
hidden_states, event, hook = self._combine_core(
hidden_states,
topk_ids,
topk_weights,
)
return hidden_states, event, hook, overlap_args
return hidden_states, event, hook

def combine_b(self, hidden_states, event, hook):
hook() if self.return_recv_hook else event.current_stream_wait()
@@ -368,7 +367,6 @@ def combine_a(
hidden_states=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
overlap_args=self.overlap_args,
)
self._combine_intermediate_state = inner_state

28 changes: 23 additions & 5 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -29,6 +29,7 @@

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
ActiveRanksOutput,
BlockReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -161,6 +162,7 @@ def __init__(
# Launch data parallel workers
self.scheduler_procs = []
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
self.status: List[int] = [1] * server_args.dp_size

if server_args.enable_dp_attention:
self.launch_dp_attention_schedulers(server_args, port_args)
@@ -183,6 +185,9 @@ def send_control_message(self, obj):
def handle_load_update_req(self, obj):
self.dp_budget.update_budget(obj)

def update_active_ranks(self, ranks: ActiveRanksOutput):
self.status = ranks.status

def dispatching_with_trace(self, req: Req):
if self.server_args.enable_trace:
trace_set_proc_propagate_context(req.rid, req.trace_context)
@@ -201,6 +206,7 @@ def init_dispatcher(self):
(TokenizedEmbeddingReqInput, self.dispatching_with_trace),
(BlockReqInput, self.send_to_all_workers),
(WatchLoadUpdateReq, self.handle_load_update_req),
(ActiveRanksOutput, self.update_active_ranks),
]
)
self._request_dispatcher.add_fallback_fn(self.send_control_message)
@@ -473,15 +479,27 @@ def round_robin_scheduler(self, req: Req):
return

if self.server_args.disaggregation_mode == "null":
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
while True:
if self.status[self.round_robin_counter] == 1:
logger.info(f"Choose worker {self.round_robin_counter}")
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
break
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
else:
assert (
req.bootstrap_room is not None
), "req.bootstrap_room should not be None. Do not send requests directly to prefill or decode instances, but send to the router instead."
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
id = req.bootstrap_room % len(self.workers)
while True:
if self.status[id] == 1:
self.workers[id].send_pyobj(req)
break
id = (id + 1) % len(self.workers)

def shortest_queue_scheduler(self, req):
if self.maybe_external_dp_rank_routing(req):
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -1385,6 +1385,11 @@ def __post_init__(self):
self.rid = ""


@dataclass
class ActiveRanksOutput(BaseReq):
status: List[int]


@dataclass
class GetInternalStateReq(BaseReq):
pass
20 changes: 19 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -59,7 +59,12 @@
TransferBackend,
prepare_abort,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.distributed import (
get_pp_group,
get_tp_active_ranks,
get_tp_active_ranks_cpu,
get_world_group,
)
from sglang.srt.dllm.config import DllmConfig
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
@@ -68,6 +73,7 @@
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.io_struct import (
AbortReq,
ActiveRanksOutput,
BaseBatchReq,
BaseReq,
BatchTokenizedEmbeddingReqInput,
@@ -157,6 +163,7 @@
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.parser.reasoning_parser import ReasoningParser
@@ -2161,6 +2168,17 @@ def run_batch(
batch_result.extend_logprob_start_len_per_req = (
extend_logprob_start_len_per_req
)
if (
self.server_args.enable_dp_attention
and self.server_args.elastic_ep_backend == "mooncake"
):
# Get the tensors indicating rank activeness
tp_active_ranks = get_tp_active_ranks().detach().cpu().numpy()
tp_active_ranks_cpu = get_tp_active_ranks_cpu().detach().numpy()
tp_active_ranks &= tp_active_ranks_cpu
self.send_to_tokenizer.send_output(
ActiveRanksOutput(status=tp_active_ranks.tolist())
)
ret = batch_result
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/scheduler_dp_attn_mixin.py
@@ -6,8 +6,10 @@
import torch

from sglang.srt.batch_overlap.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.distributed.parallel_state import get_tp_active_ranks, get_tp_active_ranks_cpu
from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils.common import require_mlp_tp_gather

if TYPE_CHECKING:
@@ -61,6 +63,15 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup):
local_info_tensor,
group=group,
)
if device == "cpu":
tp_active_ranks = get_tp_active_ranks_cpu()
else:
tp_active_ranks = get_tp_active_ranks()
global_info_tensor.view(-1, 6)[tp_active_ranks == 0, :] = torch.tensor(
[0, 1, 0, 0, 1, ForwardMode.IDLE.value],
device=global_info_tensor.device,
dtype=global_info_tensor.dtype,
)

tp0_info = global_info_tensor[:, 0, :]
self.tp0_info = tp0_info
@@ -142,6 +153,7 @@ def prepare_mlp_sync_batch_raw(
if len(offload_tags) == 0 and disable_overlap_schedule:
group = tp_group.device_group
device = tp_group.device
torch.distributed.barrier(group=tp_group.cpu_group)
else:
group = tp_group.cpu_group
device = "cpu"
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -47,6 +47,7 @@
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
ActiveRanksOutput,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
@@ -407,6 +408,7 @@ def __init__(
(FreezeGCReq, lambda x: None),
# For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
(HealthCheckOutput, lambda x: None),
(ActiveRanksOutput, self.update_active_ranks),
]
)
self.init_communicators(server_args)
@@ -2013,6 +2015,9 @@ def _handle_abort_req(self, recv_obj: AbortReq):
state.out_list.append(out)
state.event.set()

def update_active_ranks(self, ranks: ActiveRanksOutput):
self.send_to_scheduler.send_pyobj(ranks)

def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
21 changes: 21 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -2687,6 +2687,27 @@ def forward(
split_forward_count,
)

elastic_ep_state = ElasticEPStateManager.instance()
if (
elastic_ep_state is not None
and not elastic_ep_state.is_active_equal_last()
):
elastic_ep_state.snapshot_active_to_last()
elastic_ep_state.sync_active_to_cpu()
logging.info("EPLB due to rank faults")
gen = self.eplb_manager.rebalance()
while True:
try:
next(gen)
except StopIteration:
break
output = self._forward_raw(
forward_batch,
skip_attn_backend_init,
pp_proxy_tensors,
reinit_attn_backend,
split_forward_count,
)
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end()

Expand Down