@@ -2910,7 +2910,7 @@ def execute_model(
29102910
29112911 cascade_attn_prefix_lens = None
29122912 # Disable cascade attention when using microbatching (DBO)
2913- if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
2913+ if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
29142914 # Pre-compute cascade attention prefix lengths
29152915 cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
29162916 num_scheduled_tokens_np,
@@ -2950,11 +2950,11 @@ def execute_model(
29502950 num_scheduled_tokens_np,
29512951 num_tokens_padded,
29522952 num_reqs_padded,
2953- self.parallel_config.num_of_microbatches,
2953+ self.parallel_config.num_of_ubatches,
29542954 )
29552955
2956- logger.info(
2957- "jcz ubatch_slices: %s, ubatch_slices_padded: %s",
2956+ logger.debug(
2957+ "ubatch_slices: %s, ubatch_slices_padded: %s",
29582958 ubatch_slices,
29592959 ubatch_slices_padded,
29602960 )
@@ -3624,11 +3624,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
36243624 # wrap the model with full cudagraph wrapper if needed.
36253625 cudagraph_mode = self.compilation_config.cudagraph_mode
36263626 assert cudagraph_mode is not None
3627- if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
3627+ if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.use_ubatching:
36283628 self.model = CUDAGraphWrapper(
36293629 self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
36303630 )
3631- elif self.parallel_config.enable_dbo:
3631+ elif self.parallel_config.use_ubatching:
36323632 if cudagraph_mode.has_full_cudagraphs():
36333633 self.model = UBatchWrapper(
36343634 self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
@@ -3999,10 +3999,10 @@ def _dummy_run(
39993999 num_scheduled_tokens,
40004000 num_tokens_padded,
40014001 num_reqs_padded,
4002- self.vllm_config.parallel_config.num_of_microbatches,
4002+ self.vllm_config.parallel_config.num_of_ubatches,
40034003 )
4004- logger.info(
4005- "jcz ubatch_slices: %s, ubatch_slices_padded: %s",
4004+ logger.debug(
4005+ "ubatch_slices: %s, ubatch_slices_padded: %s",
40064006 ubatch_slices,
40074007 ubatch_slices_padded,
40084008 )
@@ -4529,8 +4529,8 @@ def _capture_cudagraphs(
45294529 # is above the threshold. Otherwise we just capture a non-ubatched
45304530 # version of the graph
45314531 allow_microbatching = (
4532- self.parallel_config.enable_dbo
4533- and self.parallel_config.num_of_microbatches > 1
4532+ self.parallel_config.use_ubatching
4533+ and self.parallel_config.num_of_ubatches > 1
45344534 and cudagraph_runtime_mode == CUDAGraphMode.FULL
45354535 and uniform_decode
45364536 and check_ubatch_thresholds(
@@ -4662,8 +4662,8 @@ def initialize_metadata_builders(
46624662 if kv_cache_group_id < len(kernel_block_sizes)
46634663 else None,
46644664 num_metadata_builders=1
4665- if not self.parallel_config.enable_dbo
4666- else self.parallel_config.num_of_microbatches,
4665+ if not self.parallel_config.use_ubatching
4666+ else self.parallel_config.num_of_ubatches,
46674667 )
46684668 # Calculate reorder batch threshold (if needed)
46694669 # Note (tdoublep): do this *after* constructing builders,
0 commit comments