[Feature] Add initial support for sequence parallelism #1436

Draft · wants to merge 1 commit into base: main
16 changes: 12 additions & 4 deletions python/sglang/bench_latency.py
@@ -115,7 +115,7 @@ def from_cli_args(cls, args: argparse.Namespace):
)


def load_model(server_args, tp_rank):
def load_model(server_args, tp_rank, sp_rank: int = 0):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

@@ -130,6 +130,8 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
sp_rank=sp_rank,
sp_size=server_args.sp_size,
nccl_port=28888,
server_args=server_args,
)
@@ -206,6 +208,8 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
sp_size=model_runner.sp_size,
sp_rank=model_runner.sp_rank,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
@@ -225,11 +229,12 @@ def correctness_test(
server_args,
bench_args,
tp_rank,
sp_rank=0,
):
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -336,11 +341,12 @@ def latency_test(
server_args,
bench_args,
tp_rank,
sp_rank=0,
):
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

# Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test(
@@ -458,16 +464,18 @@ def main(server_args, bench_args):
)

if server_args.tp_size == 1:
work_func(server_args, bench_args, 0)
work_func(server_args, bench_args, 0, 0)
else:
workers = []
for tp_rank in range(server_args.tp_size):
sp_rank = tp_rank % server_args.sp_size
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
bench_args,
tp_rank,
sp_rank,
),
)
proc.start()
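For reference, the launcher change above derives each worker's `sp_rank` from its `tp_rank` via `tp_rank % server_args.sp_size`, so GPUs with adjacent TP ranks land in the same SP group. A minimal standalone sketch of this mapping, using hypothetical sizes `tp_size=4` and `sp_size=2` chosen only for illustration:

```python
# Standalone sketch (not part of the diff): how each worker's sp_rank is derived
# from its tp_rank, matching the adjacent-rank SP grouping used in this PR.
tp_size, sp_size = 4, 2  # hypothetical sizes for illustration

for tp_rank in range(tp_size):
    sp_rank = tp_rank % sp_size    # rank within the GPU's SP group
    sp_group = tp_rank // sp_size  # index of the SP group the GPU belongs to
    print(f"tp_rank={tp_rank} -> sp_group={sp_group}, sp_rank={sp_rank}")

# Expected output:
#   tp_rank=0 -> sp_group=0, sp_rank=0
#   tp_rank=1 -> sp_group=0, sp_rank=1
#   tp_rank=2 -> sp_group=1, sp_rank=0
#   tp_rank=3 -> sp_group=1, sp_rank=1
```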
1 change: 1 addition & 0 deletions python/sglang/srt/layers/parallel_utils/__init__.py
@@ -0,0 +1 @@
from .parallel_state import *
96 changes: 96 additions & 0 deletions python/sglang/srt/layers/parallel_utils/parallel_state.py
@@ -0,0 +1,96 @@
from typing import List, Optional

import torch
from vllm.distributed import initialize_model_parallel as vllm_initialize_model_parallel
from vllm.distributed.parallel_state import (
GroupCoordinator,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_world_group,
init_model_parallel_group,
)

_SP: Optional[GroupCoordinator] = None


def get_sp_group():
assert _SP is not None, "sequence parallel group is not initialized"
return _SP


def init_sequence_parallel_group(
group_ranks: List[List[int]], local_rank: int, backend: str
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
)


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
sequence_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups and sequence parallel groups.

For sequence parallelism, we partition each TP group into SP groups and assign
GPUs with adjacent ranks to the same SP group. For example, with TP size 8
and SP size 2, there is 1 TP group and 4 SP groups:
SP groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
Their KV TP ranks:
[ 0, 0], [ 1, 1], [ 2, 2], [ 3, 3]
Because KV heads are replicated within an SP group, the KV TP size is 4 (8 // 2),
and all GPUs in one SP group share the same KV TP rank, ranging from 0 to 3 across groups.
"""
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)

num_sequence_parallel_groups: int = world_size // sequence_parallel_size
global _SP
assert _SP is None, "sequence parallel group is already initialized"
group_ranks = []
for i in range(num_sequence_parallel_groups):
ranks = list(
range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
)
group_ranks.append(ranks)
_SP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend)

vllm_initialize_model_parallel(
tensor_model_parallel_size, pipeline_model_parallel_size, backend
)


def sequence_parallel_is_initialized():
return _SP is not None


def get_sequence_parallel_world_size():
return get_sp_group().world_size


def get_sequence_parallel_rank():
return get_sp_group().rank_in_group


def get_sequence_parallel_global_rank():
return get_tensor_model_parallel_rank()


# NOTE: For sequence parallelism, we partition Q tensors along the head dimension,
# but K/V tensors are partitioned along the head dimension in TP and along the
# sequence dimension in SP. Their TP size and rank are therefore adjusted
# accordingly, as below.
def get_kv_tensor_model_parallel_world_size():
return get_tensor_model_parallel_world_size() // get_sequence_parallel_world_size()


def get_kv_tensor_model_parallel_rank():
return get_tensor_model_parallel_rank() // get_sequence_parallel_world_size()
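To make the docstring and the NOTE above concrete, here is a small standalone sketch (not part of the diff) that reproduces the SP group layout and the KV TP size/rank adjustment for the TP size 8, SP size 2 example, assuming the world size equals the TP size (a single TP group):

```python
# Standalone sketch: SP group layout and KV TP size/rank adjustment for
# TP size 8, SP size 2, assuming world_size == tp_size == 8.
world_size, sp_size = 8, 2

# Same construction as initialize_model_parallel: adjacent ranks share an SP group.
sp_groups = [
    list(range(i * sp_size, (i + 1) * sp_size))
    for i in range(world_size // sp_size)
]

# KV heads are replicated within each SP group, so the effective KV TP size shrinks
# and every GPU in one SP group maps to the same KV TP rank.
kv_tp_size = world_size // sp_size
kv_tp_ranks = [rank // sp_size for rank in range(world_size)]

print(sp_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(kv_tp_size)   # 4
print(kv_tp_ranks)  # [0, 0, 1, 1, 2, 2, 3, 3]
```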