Skip to content

Commit 1599422

Browse files
Rename num-of-microbatches to ubatch-size
Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
1 parent 7d29508 commit 1599422

File tree

10 files changed

+38
-29
lines changed

10 files changed

+38
-29
lines changed

examples/offline_inference/data_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def parse_args():
9393
help=("Enable microbatched execution"),
9494
)
9595
parser.add_argument(
96-
"--num-of-microbatches",
96+
"--ubatch-size",
9797
type=int,
9898
default=2,
9999
help=("Number of microbatches. Requires --enable-dbo to be enabled."),

vllm/config/parallel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ class ParallelConfig:
156156

157157
enable_dbo: bool = False
158158
"""Enable dual batch overlap for the model executor."""
159-
num_of_microbatches: int = 2
160-
"""Number of microbatches. Requires --enable-dbo to be enabled."""
159+
ubatch_size: int = 0
160+
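"""Number of micro-batches (ubatches) to use. Values above 1 enable micro-batching; when enable_dbo is set, 2 ubatches are used regardless."""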
161161

162162
dbo_decode_token_threshold: int = 32
163163
"""The threshold for dual batch overlap for batches only containing decodes.
@@ -330,6 +330,14 @@ def world_size_across_dp(self) -> int:
330330
"""world_size_across_dp is TPxPPxDP, it is the size of the world
331331
including data parallelism."""
332332
return self.world_size * self.data_parallel_size
333+
334+
@property
335+
def use_ubatching(self) -> bool:
336+
return self.enable_dbo or self.ubatch_size > 1
337+
338+
@property
339+
def num_of_ubatches(self) -> int:
340+
return 2 if self.enable_dbo else self.ubatch_size
333341

334342
def get_next_dp_init_port(self) -> int:
335343
"""

vllm/config/vllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ def has_blocked_weights():
851851
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
852852
)
853853

854-
if self.parallel_config.enable_dbo:
854+
if self.parallel_config.use_ubatching:
855855
a2a_backend = self.parallel_config.all2all_backend
856856
assert a2a_backend in [
857857
"deepep_low_latency",

vllm/engine/arg_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ class EngineArgs:
409409
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
410410
all2all_backend: str | None = ParallelConfig.all2all_backend
411411
enable_dbo: bool = ParallelConfig.enable_dbo
412-
num_of_microbatches: int = ParallelConfig.num_of_microbatches
412+
ubatch_size: int = ParallelConfig.ubatch_size
413413
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
414414
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
415415
disable_nccl_for_dp_synchronization: bool = (
@@ -830,8 +830,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
830830
)
831831
parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
832832
parallel_group.add_argument(
833-
"--num-of-microbatches",
834-
**parallel_kwargs["num_of_microbatches"],
833+
"--ubatch-size",
834+
**parallel_kwargs["ubatch_size"],
835835
)
836836
parallel_group.add_argument(
837837
"--dbo-decode-token-threshold",
@@ -1607,7 +1607,7 @@ def create_engine_config(
16071607
enable_expert_parallel=self.enable_expert_parallel,
16081608
all2all_backend=self.all2all_backend,
16091609
enable_dbo=self.enable_dbo,
1610-
num_of_microbatches=self.num_of_microbatches,
1610+
ubatch_size=self.ubatch_size,
16111611
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
16121612
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
16131613
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
dbo_yield_and_switch_from_comm_to_compute,
2424
dbo_yield_and_switch_from_compute_to_comm,
2525
)
26-
26+
from typing import Any
2727

2828
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
2929
"""
@@ -63,7 +63,7 @@ def __init__(
6363
# The dispatch function returns a handle that the combine function
6464
# requires. Under DBO microbatching we must track one handle per
6565
# micro-batch to avoid races between threads.
66-
self.handles = []
66+
self.handles: list[Any | None] = []
6767

6868
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
6969
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,9 +1502,10 @@ def ensure_dp_chunking_init(self):
15021502

15031503
moe = self.moe_config
15041504

1505-
if self.vllm_config.parallel_config.enable_dbo:
1506-
states_shape = (2, moe.max_num_tokens, self.hidden_size)
1507-
logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
1505+
if self.vllm_config.parallel_config.use_ubatching:
1506+
num_of_ubatches = self.vllm_config.parallel_config.num_of_ubatches
1507+
states_shape = (num_of_ubatches, moe.max_num_tokens, self.hidden_size)
1508+
logits_shape = (num_of_ubatches, moe.max_num_tokens, self.logical_num_experts)
15081509
else:
15091510
states_shape = (moe.max_num_tokens, self.hidden_size)
15101511
logits_shape = (moe.max_num_tokens, self.logical_num_experts)

vllm/v1/worker/dp_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _synchronize_dp_ranks(
132132
assert should_attempt_dp_padding == should_dp_pad
133133

134134
# Check conditions for microbatching
135-
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_of_microbatches)
135+
should_ubatch = _post_process_ubatch(tensor, parallel_config.num_of_ubatches)
136136

137137
if should_ubatch and not should_dp_pad:
138138
logger.debug_once(

vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,7 +2910,7 @@ def execute_model(
29102910

29112911
cascade_attn_prefix_lens = None
29122912
# Disable cascade attention when using microbatching (DBO)
2913-
if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
2913+
if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
29142914
# Pre-compute cascade attention prefix lengths
29152915
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
29162916
num_scheduled_tokens_np,
@@ -2950,11 +2950,11 @@ def execute_model(
29502950
num_scheduled_tokens_np,
29512951
num_tokens_padded,
29522952
num_reqs_padded,
2953-
self.parallel_config.num_of_microbatches,
2953+
self.parallel_config.num_of_ubatches,
29542954
)
29552955

2956-
logger.info(
2957-
"jcz ubatch_slices: %s, ubatch_slices_padded: %s",
2956+
logger.debug(
2957+
"ubatch_slices: %s, ubatch_slices_padded: %s",
29582958
ubatch_slices,
29592959
ubatch_slices_padded,
29602960
)
@@ -3624,11 +3624,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
36243624
# wrap the model with full cudagraph wrapper if needed.
36253625
cudagraph_mode = self.compilation_config.cudagraph_mode
36263626
assert cudagraph_mode is not None
3627-
if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
3627+
if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.use_ubatching:
36283628
self.model = CUDAGraphWrapper(
36293629
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
36303630
)
3631-
elif self.parallel_config.enable_dbo:
3631+
elif self.parallel_config.use_ubatching:
36323632
if cudagraph_mode.has_full_cudagraphs():
36333633
self.model = UBatchWrapper(
36343634
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
@@ -3999,10 +3999,10 @@ def _dummy_run(
39993999
num_scheduled_tokens,
40004000
num_tokens_padded,
40014001
num_reqs_padded,
4002-
self.vllm_config.parallel_config.num_of_microbatches,
4002+
self.vllm_config.parallel_config.num_of_ubatches,
40034003
)
4004-
logger.info(
4005-
"jcz ubatch_slices: %s, ubatch_slices_padded: %s",
4004+
logger.debug(
4005+
"ubatch_slices: %s, ubatch_slices_padded: %s",
40064006
ubatch_slices,
40074007
ubatch_slices_padded,
40084008
)
@@ -4529,8 +4529,8 @@ def _capture_cudagraphs(
45294529
# is above the threshold. Otherwise we just capture a non-ubatched
45304530
# version of the graph
45314531
allow_microbatching = (
4532-
self.parallel_config.enable_dbo
4533-
and self.parallel_config.num_of_microbatches > 1
4532+
self.parallel_config.use_ubatching
4533+
and self.parallel_config.num_of_ubatches > 1
45344534
and cudagraph_runtime_mode == CUDAGraphMode.FULL
45354535
and uniform_decode
45364536
and check_ubatch_thresholds(
@@ -4662,8 +4662,8 @@ def initialize_metadata_builders(
46624662
if kv_cache_group_id < len(kernel_block_sizes)
46634663
else None,
46644664
num_metadata_builders=1
4665-
if not self.parallel_config.enable_dbo
4666-
else self.parallel_config.num_of_microbatches,
4665+
if not self.parallel_config.use_ubatching
4666+
else self.parallel_config.num_of_ubatches,
46674667
)
46684668
# Calculate reorder batch threshold (if needed)
46694669
# Note (tdoublep): do this *after* constructing builders,

vllm/v1/worker/gpu_ubatch_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
self.comm_stream = torch.cuda.Stream(device=device)
106106
# Ubatch threads plus the main thread
107107
self.ready_barrier = threading.Barrier(
108-
self.vllm_config.parallel_config.num_of_microbatches + 1
108+
self.vllm_config.parallel_config.num_of_ubatches + 1
109109
)
110110

111111
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

vllm/v1/worker/ubatch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def is_last_ubatch_empty(
3838
def check_ubatch_thresholds(
3939
config: ParallelConfig, num_tokens: int, uniform_decode: bool
4040
) -> bool:
41-
if not config.enable_dbo:
41+
if not config.use_ubatching:
4242
return False
4343
if uniform_decode:
4444
return num_tokens >= config.dbo_decode_token_threshold

0 commit comments

Comments
 (0)