Support cuda graph for DP attention (#2061)
ispobock authored Nov 18, 2024
1 parent 11f881d commit 62832bb
Showing 9 changed files with 88 additions and 26 deletions.
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -455,6 +455,7 @@ class ScheduleBatch:

# For DP attention
global_num_tokens: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False

# For processing logprobs
return_logprob: bool = False
@@ -891,6 +892,13 @@ def prepare_for_idle(self):
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens_sum = 0
self.extend_num_tokens = 0

def prepare_for_decode(self, enable_overlap: bool = False):
@@ -1032,6 +1040,7 @@ def get_model_worker_batch(self):
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:

# For DP attention
global_num_tokens: Optional[List[int]]
can_run_dp_cuda_graph: bool

# For extend
extend_num_tokens: Optional[int]
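Note on the prepare_for_idle change above: with DP attention, a rank that has no local requests still builds an (empty) batch and runs the idle forward path alongside its peers, so the idle batch now also carries empty out_cache_loc and req_pool_indices tensors to keep every batch field well-defined on that path. A minimal sketch of the idea using CPU tensors only (the real code moves them to the model device with non_blocking=True):

import torch

def make_idle_batch_state():
    # Zero-length tensors keep the dtypes consistent with a real decode batch,
    # so an idle rank can run the same code path as ranks that have requests.
    return {
        "seq_lens": torch.empty(0, dtype=torch.int32),
        "out_cache_loc": torch.empty(0, dtype=torch.int32),
        "req_pool_indices": torch.empty(0, dtype=torch.int32),
        "seq_lens_sum": 0,
        "extend_num_tokens": 0,
    }

state = make_idle_batch_state()
print({k: tuple(v.shape) if torch.is_tensor(v) else v for k, v in state.items()})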
32 changes: 23 additions & 9 deletions python/sglang/srt/managers/scheduler.py
@@ -337,7 +337,7 @@ def watchdog_thread(self):

kill_parent_process()

- @torch.inference_mode()
+ @torch.no_grad()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
@@ -375,7 +375,7 @@ def event_loop_normal(self):

self.last_batch = batch

- @torch.inference_mode()
+ @torch.no_grad()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
@@ -411,16 +411,12 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
else:
num_tokens = local_batch.extend_num_tokens

- local_num_tokens = torch.tensor(
-     num_tokens, dtype=torch.int64, device=self.device
- )
- global_num_tokens = torch.empty(
-     self.tp_size, dtype=torch.int64, device=self.device
- )
+ local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+ global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
- group=self.tp_worker.get_tp_device_group(),
+ group=self.tp_cpu_group,
)

if local_batch is None and global_num_tokens.max().item() > 0:
@@ -429,6 +425,24 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()

# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

return local_batch

def get_idle_batch(self):
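The eligibility check added above runs on the TP CPU group: each rank contributes 1 if its local batch is decode or idle and 0 otherwise, and an all-reduce with ReduceOp.MIN leaves 1 only when every rank agrees, so can_run_dp_cuda_graph is set only for uniform decode/idle steps. A standalone sketch of that pattern on the gloo backend (the process-group setup and the is_decode_or_idle inputs are illustrative, not part of the commit):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int, is_decode_or_idle: list):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # 1 if this rank's batch is decode or idle, 0 if it is still prefilling.
    flag = torch.tensor(1 if is_decode_or_idle[rank] else 0, dtype=torch.int32)
    # MIN across ranks stays 1 only when *every* rank is decode/idle.
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: can_run_dp_cuda_graph={flag.item() == 1}")

    dist.destroy_process_group()

if __name__ == "__main__":
    # Rank 1 is still prefilling, so the reduced flag is 0 on both ranks.
    mp.spawn(worker, args=(2, [True, False]), nprocs=2)

The token-count all_gather just above it was likewise moved from the device group to the same CPU group (the tensors are no longer created with device=self.device), presumably so this scheduling bookkeeping stays off the CUDA stream.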
3 changes: 0 additions & 3 deletions python/sglang/srt/managers/tp_worker.py
@@ -128,9 +128,6 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group

- def get_tp_device_group(self):
-     return self.model_runner.tp_group.device_group

def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
5 changes: 1 addition & 4 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -83,9 +83,6 @@ def get_pad_input_ids_func(self):
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()

- def get_tp_device_group(self):
-     return self.worker.get_tp_device_group()

def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()

- @torch.inference_mode()
+ @torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()
50 changes: 44 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -111,6 +111,8 @@ def __init__(self, model_runner: "ModelRunner"):
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
self.tp_size = self.model_runner.tp_size

# Batch sizes to capture
if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ def __init__(self, model_runner: "ModelRunner"):
else:
self.encoder_lens = None

if self.enable_dp_attention:
self.global_num_tokens = [0] * self.tp_size
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.tp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)

# Capture
try:
with self.model_capture_mode():
@@ -190,11 +202,21 @@ def model_capture_mode(self):
self.model_runner.model.capture_mode = False

def can_run(self, forward_batch: ForwardBatch):
- is_bs_supported = (
-     forward_batch.batch_size in self.graphs
-     if self.disable_padding
-     else forward_batch.batch_size <= self.max_bs
- )
+ if self.enable_dp_attention:
+     min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+         forward_batch.global_num_tokens
+     )
+     is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+         (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+         if self.disable_padding
+         else max_num_tokens <= self.max_bs
+     )
+ else:
+     is_bs_supported = (
+         forward_batch.batch_size in self.graphs
+         if self.disable_padding
+         else forward_batch.batch_size <= self.max_bs
+     )

# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
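Under DP attention, whether a captured graph can be reused depends on every rank's token count rather than only the local batch size: replay is considered only when can_run_dp_cuda_graph is set, and either all ranks report the same already-captured token count (padding disabled) or the largest count still fits under max_bs (padding enabled). A pure-Python restatement of the predicate above (names mirror the diff; the graphs dict stands in for the table of captured graphs):

from typing import Dict, List

def dp_graph_is_bs_supported(
    global_num_tokens: List[int],
    can_run_dp_cuda_graph: bool,
    graphs: Dict[int, object],
    max_bs: int,
    disable_padding: bool,
) -> bool:
    if not can_run_dp_cuda_graph:
        return False
    min_tokens, max_tokens = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        # Without padding, every rank must hit one exact captured size.
        return min_tokens == max_tokens and max_tokens in graphs
    # With padding, it is enough that the largest rank fits the biggest graph.
    return max_tokens <= max_bs

graphs = {1: None, 2: None, 4: None, 8: None}
print(dp_graph_is_bs_supported([3, 5, 4], True, graphs, 8, disable_padding=False))  # True
print(dp_graph_is_bs_supported([3, 5, 4], True, graphs, 8, disable_padding=True))   # False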
@@ -239,6 +261,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
seq_lens_sum = seq_lens.sum().item()
mrope_positions = self.mrope_positions[:, :bs]

if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
else:
self.global_num_tokens = None
gathered_buffer = None

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
@@ -265,6 +294,8 @@ def run_once():
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=self.global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits
@@ -295,7 +326,12 @@ def replay(self, forward_batch: ForwardBatch):
raw_bs = forward_batch.batch_size

# Pad
- index = bisect.bisect_left(self.capture_bs, raw_bs)
+ if self.enable_dp_attention:
+     index = bisect.bisect_left(
+         self.capture_bs, max(forward_batch.global_num_tokens)
+     )
+ else:
+     index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
@@ -310,6 +346,8 @@ def replay(self, forward_batch: ForwardBatch):
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
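During replay with DP attention, the padded batch size is chosen from the largest token count across ranks rather than from the local batch size, and that common size is written back into global_num_tokens so every rank slices the same number of rows out of the gathered buffer (which __init__ now allocates with max_bs * tp_size rows). A small sketch of the size selection, assuming capture_bs is sorted ascending as in the runner:

import bisect
from typing import List

def padded_capture_size(capture_bs: List[int], global_num_tokens: List[int]) -> int:
    # Smallest captured batch size that can hold the largest per-rank token
    # count; all ranks then replay the graph captured at this common size.
    index = bisect.bisect_left(capture_bs, max(global_num_tokens))
    return capture_bs[index]

capture_bs = [1, 2, 4, 8, 16, 24, 32]
print(padded_capture_size(capture_bs, [3, 5, 4]))  # 8
print(padded_capture_size(capture_bs, [1, 1, 1]))  # 1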
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -138,6 +138,7 @@ class ForwardBatch:
# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False

def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ def init_new(
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -592,6 +592,9 @@ def forward_extend(self, forward_batch: ForwardBatch):
)

def forward_idle(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)

return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
5 changes: 3 additions & 2 deletions python/sglang/srt/server_args.py
@@ -191,11 +191,12 @@ def __post_init__(self):
if self.enable_dp_attention:
self.dp_size = self.tp_size
self.chunked_prefill_size = self.chunked_prefill_size // 2
- self.disable_cuda_graph = True
+ self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
self.enable_overlap_schedule = False
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
"The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
"Data parallel size is adjusted to be the same as tensor parallel size."
)

if self.enable_overlap_schedule:
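With these server_args changes, enabling DP attention no longer disables CUDA graphs; it only caps the captured batch size. A hedged sketch of the resulting __post_init__ adjustments as a standalone function (field names come from the diff; the class and the concrete values below are illustrative):

from dataclasses import dataclass

@dataclass
class DPAttentionArgs:
    tp_size: int
    dp_size: int
    chunked_prefill_size: int
    cuda_graph_max_bs: int
    enable_overlap_schedule: bool = True

def apply_dp_attention_defaults(args: DPAttentionArgs) -> DPAttentionArgs:
    # DP size follows TP size, chunked prefill is halved to ease the MoE
    # workload, CUDA graphs stay enabled but capture at most batch size 96,
    # and the overlap schedule is turned off.
    args.dp_size = args.tp_size
    args.chunked_prefill_size //= 2
    args.cuda_graph_max_bs = min(args.cuda_graph_max_bs, 96)
    args.enable_overlap_schedule = False
    return args

print(apply_dp_attention_defaults(
    DPAttentionArgs(tp_size=8, dp_size=1, chunked_prefill_size=8192, cuda_graph_max_bs=160)
))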
4 changes: 2 additions & 2 deletions scripts/playground/reference_hf.py
@@ -31,7 +31,7 @@
from sglang.srt.hf_transformers_utils import get_tokenizer


- @torch.inference_mode()
+ @torch.no_grad()
def normal_text(args):
t = get_tokenizer(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
print(output_str)


- @torch.inference_mode()
+ @torch.no_grad()
def synthetic_tokens(args):
m = AutoModelForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
